# Compare commits

`feat/use-c…` … `main` — 83 commits
| SHA1 |
|---|
| 42a0086ba1 |
| 673dc60178 |
| 029c20c39e |
| 136e31c82c |
| 2e8ca0ee43 |
| 1481a7baac |
| 703da9035e |
| f6934bdf2a |
| ce33aff174 |
| 842c4da90c |
| 95a576f744 |
| 81e3b9a7d0 |
| 763407ba1c |
| 34f8268cd5 |
| 9c73895318 |
| 2efc4d1088 |
| f6ce97d7fd |
| 9e33923f71 |
| 65670af78f |
| 9cb62d69af |
| 4c0746950e |
| a418969251 |
| fae8c3267f |
| d90117807b |
| 44c3428993 |
| c7dc40c51e |
| 39a6c10634 |
| 9f793da778 |
| 4802e086a0 |
| f54f0d317d |
| e6fcd5fc54 |
| 1c31444611 |
| a9e510362d |
| 5ee74f367c |
| f6cc99f062 |
| d0648fe01d |
| 5841bc09c0 |
| 6d1bc720b4 |
| 3c7d607776 |
| 4646180942 |
| c234b67bbf |
| ebefee4184 |
| b737ed7b21 |
| 322f6b2b1b |
| 0f045f814a |
| 90e46b707d |
| 6183b9c886 |
| 050f80dcd7 |
| 415e03fba1 |
| 406a7ea2fd |
| ee023d6e34 |
| e46c44f1e0 |
| 04a415a191 |
| 141153ffa7 |
| 8bb220ae43 |
| 95728accbf |
| dc6d28bda1 |
| 1c60c30084 |
| a54a968313 |
| b109bba873 |
| 118d77c428 |
| 565d8d0676 |
| 83c1996702 |
| 132f110463 |
| 6d9c239e82 |
| abee9cea7b |
| acb2d55ce3 |
| 81054baa06 |
| 632acdcd26 |
| 97aa24f478 |
| d801038c74 |
| 290e51416f |
| 2709fb8d6b |
| 118a9abd09 |
| 1344b9ddb4 |
| dcd1bc764a |
| b397a80c0b |
| 1e340c82fa |
| 2d22115893 |
| 527fc620fe |
| b2ff27c1ca |
| 1321d57354 |
| 810979e416 |
101 changed files with 13562 additions and 77 deletions
`.env.example`:

```diff
@@ -4,11 +4,11 @@
 # the Postgres password.

 # --- Job store -----------------------------------------------------------
-IX_POSTGRES_URL=postgresql+asyncpg://infoxtractor:<password>@host.docker.internal:5431/infoxtractor
+IX_POSTGRES_URL=postgresql+asyncpg://infoxtractor:<password>@127.0.0.1:5431/infoxtractor

 # --- LLM backend ---------------------------------------------------------
-IX_OLLAMA_URL=http://host.docker.internal:11434
-IX_DEFAULT_MODEL=gpt-oss:20b
+IX_OLLAMA_URL=http://127.0.0.1:11434
+IX_DEFAULT_MODEL=qwen3:14b

 # --- OCR -----------------------------------------------------------------
 IX_OCR_ENGINE=surya
```
`.gitignore` (vendored, 1 addition):

```diff
@@ -15,6 +15,7 @@ dist/
 build/
 *.log
 /tmp/
+.claude/
 # uv
 # uv.lock is committed intentionally for reproducible builds.

```
```diff
@@ -4,7 +4,11 @@ Async, on-prem, LLM-powered structured information extraction microservice. Give

 Designed to be used by other on-prem services (e.g. mammon) as a reliable fallback / second opinion for format-specific deterministic parsers.

-Status: design phase. Full reference spec at `docs/spec-core-pipeline.md`. MVP spec will live at `docs/superpowers/specs/`.
+Status: MVP deployed (2026-04-18) at `http://192.168.68.42:8994` — LAN only. Browser UI at `http://192.168.68.42:8994/ui`. Full reference spec at `docs/spec-core-pipeline.md`; MVP spec at `docs/superpowers/specs/2026-04-18-ix-mvp-design.md`; deploy runbook at `docs/deployment.md`.
+
+Use cases: the built-in registry lives in `src/ix/use_cases/__init__.py` (`bank_statement_header` for MVP). Callers without a registered entry can ship an ad-hoc schema inline via `RequestIX.use_case_inline` (see README "Ad-hoc use cases"); the pipeline builds the Pydantic classes on the fly per request. The `/ui` page exposes this as a "custom" option so non-engineering users can experiment without a deploy.
+
+UX notes: the `/ui` job page surfaces queue position + elapsed MM:SS on each poll, renders the client-provided filename (stored via `FileRef.display_name`, optional metadata — the pipeline ignores it for execution), and shows a CPU-mode notice when `/healthz` reports `ocr_gpu: false`. A paginated history lives at `/ui/jobs` (status + client_id filters, newest first).

 ## Guiding Principles

@@ -25,7 +29,7 @@ Status: design phase. Full reference spec at `docs/spec-core-pipeline.md`. MVP s
 - **Language**: Python 3.12, asyncio
 - **Web/REST**: FastAPI + uvicorn
 - **OCR (pluggable)**: Surya OCR first (GPU, shares RTX 3090 with Ollama / Immich ML)
-- **LLM**: Ollama at `192.168.68.42:11434`, structured outputs via JSON schema. Initial model candidate: `qwen2.5:32b` / `gpt-oss:20b`, configurable per use case
+- **LLM**: Ollama at `192.168.68.42:11434`, structured outputs via JSON schema. Initial model candidate: `qwen2.5:32b` / `qwen3:14b`, configurable per use case
 - **State**: Postgres on the shared `postgis` container (:5431), new `infoxtractor` database
 - **Deployment**: Docker, `git push server main` → post-receive rebuild (pattern from other apps)
```
`Dockerfile` (new file, 69 lines):

```dockerfile
# InfoXtractor container image.
#
# Base image ships CUDA 12.4 runtime libraries so the Surya OCR client can
# use the RTX 3090 on the deploy host. Ubuntu 22.04 is the LTS used across
# the home-server stack (immich-ml, monitoring) so GPU drivers line up.
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

# --- System deps --------------------------------------------------------
# - python3.12 via deadsnakes PPA (pinned; Ubuntu 22.04 ships 3.10 only)
# - libmagic1 : python-magic backend for MIME sniffing
# - libgl1 : libGL.so needed by Pillow/OpenCV wheels used by Surya
# - libglib2.0 : shared by Pillow/PyMuPDF headless rendering
# - curl : post-receive hook's /healthz probe & general ops
# - ca-certs : httpx TLS verification
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        software-properties-common \
        ca-certificates \
        curl \
        gnupg \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        python3.12 \
        python3.12-venv \
        python3.12-dev \
        libmagic1 \
        libgl1 \
        libglib2.0-0 \
    && ln -sf /usr/bin/python3.12 /usr/local/bin/python \
    && ln -sf /usr/bin/python3.12 /usr/local/bin/python3 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# --- uv (dependency resolver used by the project) -----------------------
# Install via the standalone installer; avoids needing a working system pip
# (python3.12 on Ubuntu 22.04 has no `distutils`, which breaks Ubuntu pip).
RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
    && ln -sf /root/.local/bin/uv /usr/local/bin/uv

WORKDIR /app

# Copy dependency manifests + README early so the heavy `uv sync` layer
# caches whenever only application code changes. README.md is required
# because pyproject.toml names it as the package's readme — hatchling
# validates it exists when resolving the editable install.
COPY pyproject.toml uv.lock .python-version README.md ./

# Prod + OCR extras, no dev tooling. --frozen means "must match uv.lock";
# CI catches drift before it reaches the image.
RUN uv sync --frozen --no-dev --extra ocr

# --- Application code ---------------------------------------------------
COPY src src
COPY alembic alembic
COPY alembic.ini ./

EXPOSE 8994

# Migrations are idempotent (alembic upgrade head is a no-op on a current
# DB) so running them on every start keeps the image + DB aligned without
# an extra orchestration step.
CMD ["sh", "-c", "uv run alembic upgrade head && uv run uvicorn ix.app:create_app --factory --host 0.0.0.0 --port 8994"]
```
`README.md`:

```diff
@@ -4,10 +4,18 @@ Async, on-prem, LLM-powered structured information extraction microservice.

 Given a document (PDF, image, text) and a named *use case*, ix returns a structured JSON result whose shape matches the use-case schema — together with per-field provenance (OCR segment IDs, bounding boxes, cross-OCR agreement flags) that let the caller decide how much to trust each extracted value.

-**Status:** design phase. Implementation about to start.
+**Status:** MVP deployed. Live on the home LAN at `http://192.168.68.42:8994` (REST API + browser UI at `/ui`).
+
+## Web UI
+
+A minimal browser UI lives at [`http://192.168.68.42:8994/ui`](http://192.168.68.42:8994/ui): drop a PDF, pick a registered use case or define one inline, submit, see the pretty-printed result. HTMX polls the job status every 2 s until the pipeline finishes. LAN-only, no auth.
+
+Past submissions are browsable at [`/ui/jobs`](http://192.168.68.42:8994/ui/jobs) — a paginated list (newest first) with status + `client_id` filters. Each row links to `/ui/jobs/{job_id}` for the full request/response view.

 - Full reference spec: [`docs/spec-core-pipeline.md`](docs/spec-core-pipeline.md) (aspirational; MVP is a strict subset)
+- **MVP design:** [`docs/superpowers/specs/2026-04-18-ix-mvp-design.md`](docs/superpowers/specs/2026-04-18-ix-mvp-design.md)
+- **Implementation plan:** [`docs/superpowers/plans/2026-04-18-ix-mvp-implementation.md`](docs/superpowers/plans/2026-04-18-ix-mvp-implementation.md)
+- **Deployment runbook:** [`docs/deployment.md`](docs/deployment.md)
 - Agent / development notes: [`AGENTS.md`](AGENTS.md)

 ## Principles
```
The second hunk (`@@ -15,3 +23,85 @@`) appends the usage documentation after the existing Principles bullets:

- **On-prem always.** LLM = Ollama, OCR = local engines (Surya first). No OpenAI / Anthropic / Azure / AWS / cloud.
- **Grounded extraction, not DB truth.** ix returns best-effort fields + provenance; the caller decides what to trust.
- **Transport-agnostic pipeline core.** REST + Postgres-queue adapters in parallel on one job store.

## Submitting a job

```bash
curl -X POST http://192.168.68.42:8994/jobs \
  -H "Content-Type: application/json" \
  -d '{
    "use_case": "bank_statement_header",
    "ix_client_id": "mammon",
    "request_id": "some-correlation-id",
    "context": {
      "files": [{
        "url": "http://paperless.local/api/documents/42/download/",
        "headers": {"Authorization": "Token …"}
      }],
      "texts": ["<Paperless Tesseract OCR content>"]
    }
  }'
# → {"job_id":"…","ix_id":"…","status":"pending"}
```

Poll `GET /jobs/{job_id}` until `status` is `done` or `error`. Optionally pass `callback_url` to receive a webhook on completion (one-shot, no retry; polling stays authoritative).
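The submit-then-poll flow described above can be sketched as a small client loop. This is a hypothetical helper, not service code: `fetch` stands in for whatever HTTP client the caller uses, and the status names match the job-store states (`pending`/`running`/`done`/`error`).

```python
import time

TERMINAL = {"done", "error"}


def poll_job(fetch, job_id, interval_s=2.0, timeout_s=120.0):
    """Poll GET /jobs/{job_id} until the job reaches a terminal status.

    `fetch` is any callable mapping a path to a decoded JSON dict, so the
    loop stays HTTP-client-agnostic (urllib, httpx, a test stub, ...).
    """
    deadline = time.monotonic() + timeout_s
    while True:
        job = fetch(f"/jobs/{job_id}")
        if job["status"] in TERMINAL:
            return job
        if time.monotonic() > deadline:
            raise TimeoutError(f"job {job_id} still {job['status']} after {timeout_s}s")
        time.sleep(interval_s)
```

Injecting `fetch` keeps the loop trivially testable and mirrors the "polling stays authoritative" contract: the webhook is a latency optimisation, not the source of truth.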
### Ad-hoc use cases

For one-offs where a registered use case doesn't exist yet, ship the schema inline:

```jsonc
{
  "use_case": "adhoc-invoice",   // free-form label (logs/metrics only)
  "use_case_inline": {
    "use_case_name": "Invoice totals",
    "system_prompt": "Extract vendor and total amount.",
    "fields": [
      {"name": "vendor", "type": "str", "required": true},
      {"name": "total", "type": "decimal"},
      {"name": "currency", "type": "str", "choices": ["USD", "EUR", "CHF"]}
    ]
  },
  // ...ix_client_id, request_id, context...
}
```

When `use_case_inline` is set, the pipeline builds the response schema on the fly and skips the registry. Supported types: `str`, `int`, `float`, `decimal`, `date`, `datetime`, `bool`. `choices` is only allowed on `str` fields. Precedence: inline wins over `use_case` when both are present.

Full REST surface + provenance response shape documented in the MVP design spec.
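A sketch of the kind of on-the-fly schema building this enables. The service itself constructs Pydantic classes; this illustration emits plain JSON Schema instead, and the type mapping is an assumption consistent with the supported-types list above.

```python
# Illustrative only: maps use_case_inline field specs to JSON Schema.
# The real pipeline builds Pydantic models, not raw JSON Schema.
TYPE_MAP = {
    "str": {"type": "string"},
    "int": {"type": "integer"},
    "float": {"type": "number"},
    "decimal": {"type": "string"},  # decimals as strings to avoid float loss
    "date": {"type": "string", "format": "date"},
    "datetime": {"type": "string", "format": "date-time"},
    "bool": {"type": "boolean"},
}


def inline_schema(fields: list[dict]) -> dict:
    """Build an object schema from a use_case_inline `fields` list."""
    props, required = {}, []
    for f in fields:
        if "choices" in f and f["type"] != "str":
            raise ValueError(f"choices only allowed on str fields: {f['name']}")
        spec = dict(TYPE_MAP[f["type"]])
        if "choices" in f:
            spec["enum"] = list(f["choices"])
        props[f["name"]] = spec
        if f.get("required"):
            required.append(f["name"])
    return {"type": "object", "properties": props, "required": required}
```

Enforcing the `choices`-on-`str`-only rule at build time means a malformed inline schema fails the request immediately instead of surfacing mid-pipeline.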
## Running locally

```bash
uv sync --extra dev
uv run pytest tests/unit -v                    # hermetic unit + integration suite
IX_TEST_OLLAMA=1 uv run pytest tests/live -v   # needs LAN access to Ollama + GPU
```

### UI jobs list

`GET /ui/jobs` renders a paginated, newest-first table of submitted jobs. Query params:

- `status=pending|running|done|error` — repeat for multi-select.
- `client_id=<str>` — exact match (e.g. `ui`, `mammon`).
- `limit=<n>` (default 50, max 200) + `offset=<n>` for paging.

Each row shows status badge, original filename (`FileRef.display_name` or URL basename), use case, client id, submitted time + relative, and elapsed wall-clock (terminal rows only). Each row links to `/ui/jobs/{job_id}` for the full response JSON.

### UI queue + progress UX

The `/ui` job page polls `GET /ui/jobs/{id}/fragment` every 2 s and surfaces:

- **Queue position** while pending: "Queue position: N ahead — M jobs total in flight (single worker)" so it's obvious a new submission is waiting on an earlier job rather than stuck. "About to start" when the worker has just freed up.
- **Elapsed time** while running ("Running for MM:SS") and on finish ("Finished in MM:SS").
- **Original filename** — the UI stashes the client-provided upload name in `FileRef.display_name` so the browser shows `your_statement.pdf` instead of the on-disk UUID.
- **CPU-mode notice** when `/healthz` reports `ocr_gpu: false` (the Surya OCR client observed `torch.cuda.is_available() == False`): a collapsed `<details>` pointing at the deployment runbook.
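Both the queue-position and elapsed-time figures are cheap derivations over the job rows. A sketch with hypothetical helper names; FIFO ordering by `created_at` matches the single-worker claim order:

```python
from datetime import datetime


def queue_position(pending_created_at: list[datetime], mine: datetime) -> int:
    """How many pending jobs were submitted before mine.

    Single worker, FIFO by created_at, so position is just a count of
    earlier submissions still in the pending state.
    """
    return sum(1 for t in pending_created_at if t < mine)


def elapsed_mmss(started: datetime, now: datetime) -> str:
    """Render elapsed wall-clock as MM:SS for the 'Running for MM:SS' badge."""
    total = max(0, int((now - started).total_seconds()))
    return f"{total // 60:02d}:{total % 60:02d}"
```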
## Deploying

```bash
git push server main          # rebuilds Docker image, restarts container, /healthz deploy gate
python scripts/e2e_smoke.py   # E2E acceptance against the live service
```

See [`docs/deployment.md`](docs/deployment.md) for full runbook + rollback.
`alembic.ini` (new file, 47 lines):

```ini
; Alembic configuration for infoxtractor.
;
; The sqlalchemy.url is filled at runtime from the IX_POSTGRES_URL env var
; (alembic/env.py does the substitution). We keep the template here so
; ``alembic check`` / ``alembic history`` tools work without an env var set.

[alembic]
script_location = alembic
file_template = %%(rev)s_%%(slug)s
prepend_sys_path = .
sqlalchemy.url = driver://user:pass@localhost/dbname

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
```
`alembic/env.py` (new file, 89 lines):

```python
"""Alembic async env — reads ``IX_POSTGRES_URL`` from the environment.

Mirrors mammon's ``alembic/env.py`` pattern (async engine + ``run_sync`` bridge)
so anyone familiar with that repo can read this one without context switch.
The only deviations:

* We source the URL from ``IX_POSTGRES_URL`` via ``os.environ`` rather than via
  the alembic.ini ``sqlalchemy.url`` setting. Config parsing happens at import
  time and depending on pydantic-settings here would introduce a cycle with
  ``src/ix/config.py`` (which lands in Task 3.2).
* We use ``NullPool`` — migrations open/close their connection once, pooling
  would hold an unused async connection open after ``alembic upgrade head``
  returned, which breaks the container's CMD chain.

Run offline by setting ``-x url=...`` or the env var + ``--sql``.
"""

from __future__ import annotations

import asyncio
import os
from logging.config import fileConfig

from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool

from ix.store.models import Base

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata


def _database_url() -> str:
    """Resolve the connection URL from env, falling back to alembic.ini.

    The env var is the primary source (container CMD sets it). The ini value
    remains available so ``alembic -x url=...`` or a manual ``alembic.ini``
    edit still work for one-off scripts.
    """
    env_url = os.environ.get("IX_POSTGRES_URL")
    if env_url:
        return env_url
    ini_url = config.get_main_option("sqlalchemy.url")
    if ini_url and ini_url != "driver://user:pass@localhost/dbname":
        return ini_url
    raise RuntimeError(
        "IX_POSTGRES_URL not set and alembic.ini sqlalchemy.url not configured"
    )


def run_migrations_offline() -> None:
    """Emit migrations as SQL without a live connection."""
    context.configure(
        url=_database_url(),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection) -> None:  # type: ignore[no-untyped-def]
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    engine = create_async_engine(_database_url(), poolclass=NullPool)
    async with engine.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await engine.dispose()


def run_migrations_online() -> None:
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
```
`alembic/versions/001_initial_ix_jobs.py` (new file, 90 lines):

```python
"""Initial migration — creates the ``ix_jobs`` table per spec §4.

Hand-written (do NOT ``alembic revision --autogenerate``) so the table layout
stays byte-exact with the MVP spec. autogenerate tends to add/drop indexes in
an order that makes diffs noisy and occasionally swaps JSONB for JSON on
dialects that don't distinguish them.

Revision ID: 001
Revises:
Create Date: 2026-04-18
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "001"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create ``ix_jobs`` + its indexes exactly as spec §4 describes.

    JSONB for ``request`` and ``response`` (Postgres-only; the MVP doesn't
    support any other backend). CHECK constraint bakes the status enum into
    the DDL so direct SQL inserts (the pg_queue_adapter path) can't land
    bogus values. The partial index on ``status='pending'`` matches the
    claim query's ``WHERE status='pending' ORDER BY created_at`` pattern.
    """
    op.create_table(
        "ix_jobs",
        sa.Column("job_id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("ix_id", sa.Text(), nullable=False),
        sa.Column("client_id", sa.Text(), nullable=False),
        sa.Column("request_id", sa.Text(), nullable=False),
        sa.Column("status", sa.Text(), nullable=False),
        sa.Column("request", postgresql.JSONB(), nullable=False),
        sa.Column("response", postgresql.JSONB(), nullable=True),
        sa.Column("callback_url", sa.Text(), nullable=True),
        sa.Column("callback_status", sa.Text(), nullable=True),
        sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
        sa.CheckConstraint(
            "status IN ('pending', 'running', 'done', 'error')",
            name="ix_jobs_status_check",
        ),
        sa.CheckConstraint(
            "callback_status IS NULL OR callback_status IN "
            "('pending', 'delivered', 'failed')",
            name="ix_jobs_callback_status_check",
        ),
    )

    # Partial index: the claim query hits only pending rows ordered by age.
    # Partial-ness keeps the index small as done/error rows accumulate.
    op.create_index(
        "ix_jobs_status_created",
        "ix_jobs",
        ["status", "created_at"],
        postgresql_where=sa.text("status = 'pending'"),
    )
    # Unique index on (client_id, request_id) enforces caller-side idempotency
    # at the DB layer. The repo relies on the unique violation to detect an
    # existing pending/running row and return it unchanged.
    op.create_index(
        "ix_jobs_client_request",
        "ix_jobs",
        ["client_id", "request_id"],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index("ix_jobs_client_request", table_name="ix_jobs")
    op.drop_index("ix_jobs_status_created", table_name="ix_jobs")
    op.drop_table("ix_jobs")
```
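The two indexes above are shaped around two statements. A sketch of what they might look like; the actual repo-layer SQL may differ, and both `FOR UPDATE SKIP LOCKED` and the `ON CONFLICT` shape are assumptions here, shown only as the common idioms for this pattern:

```sql
-- Claim the oldest pending job (served by the partial index).
UPDATE ix_jobs
   SET status = 'running', started_at = now()
 WHERE job_id = (
         SELECT job_id FROM ix_jobs
          WHERE status = 'pending'
          ORDER BY created_at
          LIMIT 1
          FOR UPDATE SKIP LOCKED
       )
RETURNING job_id;

-- Idempotent submit: the unique (client_id, request_id) index turns a
-- duplicate submission into a no-op insert; the caller then re-reads
-- and returns the existing row unchanged.
INSERT INTO ix_jobs (job_id, ix_id, client_id, request_id, status, request)
VALUES (:job_id, :ix_id, :client_id, :request_id, 'pending', :request)
ON CONFLICT (client_id, request_id) DO NOTHING;
```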
`docker-compose.yml` (new file, 42 lines):

```yaml
# InfoXtractor Docker Compose stack.
#
# Single service. Uses host networking so the container can reach:
#   - Ollama at 127.0.0.1:11434
#   - postgis at 127.0.0.1:5431 (bound to loopback only; security hardening)
# Both services are LAN-hardened on the host and never exposed publicly,
# so host-network access stays on-prem. This matches the `goldstein`
# container pattern on the same server.
#
# The GPU reservation block matches immich-ml / the shape Docker Compose
# expects for GPU allocation on this host.

name: infoxtractor

services:
  infoxtractor:
    build: .
    container_name: infoxtractor
    network_mode: host
    restart: always
    env_file: .env
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    volumes:
      # Persist Surya (datalab) + HuggingFace model caches so rebuilds don't
      # re-download ~1.5 GB of weights every time.
      - ix_surya_cache:/root/.cache/datalab
      - ix_hf_cache:/root/.cache/huggingface
    labels:
      infrastructure.web_url: "http://192.168.68.42:8994"
      backup.enable: "true"
      backup.type: "postgres"
      backup.name: "infoxtractor"

volumes:
  ix_surya_cache:
  ix_hf_cache:
```
`docs/deployment.md` (new file, 153 lines):

# Deployment

On-prem deploy to `192.168.68.42`. Push-to-deploy via a bare git repo + `post-receive` hook that rebuilds the Docker Compose stack. Pattern mirrors mammon and unified_messaging.

## Topology

```
Mac (dev)
  │ git push server main
  ▼
192.168.68.42:/home/server/Public/infoxtractor/repos.git (bare)
  │ post-receive → GIT_WORK_TREE=/…/app git checkout -f main
  │ docker compose up -d --build
  │ curl /healthz (60 s gate)
  ▼
Docker container `infoxtractor` (port 8994)
  ├─ 127.0.0.1:11434 → Ollama (qwen3:14b; host-network mode)
  └─ 127.0.0.1:5431 → postgis (database `infoxtractor`; host-network mode)
```
## One-time server setup

Run **once** from the Mac. Idempotent.

```bash
export IX_POSTGRES_PASSWORD=<generate-a-strong-one>
./scripts/setup_server.sh
```

The script:

1. Creates `/home/server/Public/infoxtractor/repos.git` (bare) + `/home/server/Public/infoxtractor/app/` (worktree).
2. Installs the `post-receive` hook (see `scripts/setup_server.sh` for the template).
3. Creates the `infoxtractor` Postgres role + database on the shared `postgis` container.
4. Writes `/home/server/Public/infoxtractor/app/.env` (mode 0600) from `.env.example` with the password substituted in.
5. Verifies `qwen3:14b` is pulled in Ollama.
6. Prints a hint to open UFW for port 8994 on the LAN subnet if it's missing.

After the script finishes, add the deploy remote to the local repo:

```bash
git remote add server ssh://server@192.168.68.42/home/server/Public/infoxtractor/repos.git
```
## Normal deploy workflow

```bash
# after merging a feat branch into main
git push server main

# tail the server's deploy log
ssh server@192.168.68.42 "tail -f /tmp/infoxtractor-deploy.log"

# healthz gate (the post-receive hook also waits up to 60 s for this)
curl http://192.168.68.42:8994/healthz

# end-to-end smoke — this IS the real acceptance test
python scripts/e2e_smoke.py
```

If the post-receive hook exits non-zero (`/healthz` never reaches 200), the deploy is considered failed. `docker compose up -d --build` builds the new image first and only swaps containers on a successful build, so a build failure leaves the previous container running; if the new container starts but never passes `/healthz`, it is up but broken. Investigate with `docker compose logs --tail 200` in `${APP_DIR}` and either fix forward or revert (see below).
## Rollback

Never force-push `main`. Rollbacks happen as **forward commits** via `git revert`:

```bash
git revert HEAD   # creates a revert commit for the last change
git push forgejo main
git push server main
```
## First deploy

- **Date:** 2026-04-18
- **Commit:** `fix/ollama-extract-json` (#36, the last of several Docker/ops follow-ups after PR #27 shipped the initial Dockerfile)
- **`/healthz`:** all three probes (`postgres`, `ollama`, `ocr`) green. The first pass took ~7 min for the fresh container because Surya's recognition (1.34 GB) + detection (73 MB) models download from HuggingFace on first run; subsequent rebuilds reuse the named volumes declared in `docker-compose.yml` and come up in <30 s.
- **E2E extraction:** `bank_statement_header` against `tests/fixtures/synthetic_giro.pdf` with Paperless-style texts:
  - Pipeline completes in **35 s**.
  - Extracted: `bank_name=DKB`, `account_iban=DE89370400440532013000`, `currency=EUR`, `opening_balance=1234.56`, `closing_balance=1450.22`, `statement_date=2026-03-31`, `statement_period_end=2026-03-31`, `statement_period_start=2026-03-01`, `account_type=null`.
  - Provenance: 8 / 9 leaf fields have sources; 7 of those 8 have both `provenance_verified` and `text_agreement` True. `statement_period_start` shows up in the OCR but normalisation fails (dateutil picks a different interpretation of the cited day); to be chased in a follow-up.
### Docker-ops follow-ups that landed during the first deploy

All small, each merged as its own PR. In commit order after the scaffold (#27):

- **#31** `fix(docker): uv via standalone installer` — Python 3.12 on Ubuntu 22.04 drops `distutils`, which Ubuntu's pip needed. Switched to the `uv` standalone installer, which has no pip dependency.
- **#32** `fix(docker): include README.md in the uv sync COPY` — `hatchling` validates that the readme file exists when resolving the editable project install.
- **#33** `fix(compose): drop runtime: nvidia` — the deploy host's Docker daemon doesn't register a named `nvidia` runtime; `deploy.resources.devices` is sufficient and matches immich-ml.
- **#34** `fix(deploy): network_mode: host` — `postgis` is bound to `127.0.0.1` on the host (security hardening T12). `host.docker.internal` points at the bridge gateway, not loopback, so the container couldn't reach postgis. Goldstein uses the same pattern.
- **#35** `fix(deps): pin surya-ocr ^0.17` — the earlier cu124 torch pin had forced surya down to 0.14.1, which breaks our `surya.foundation` import and requires a transformers version that lacks `QuantizedCacheConfig`.
- **#36** `fix(genai): drop Ollama format flag; extract trailing JSON` — Ollama 0.11.8 segfaults on Pydantic JSON Schemas (`$ref`, `anyOf`, `pattern`), and `format="json"` terminates reasoning models (qwen3) at `{}` because their `<think>…</think>` chain-of-thought isn't valid JSON. We now omit the flag, inject the schema into the system prompt, and extract the outermost balanced `{…}` block from the response.
- **volumes** — named `ix_surya_cache` + `ix_hf_cache` volumes mount `/root/.cache/datalab` + `/root/.cache/huggingface` so rebuilds don't re-download ~1.5 GB of model weights.
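The "extract the outermost balanced `{…}` block" step from #36 can be sketched as follows. The helper name is hypothetical and the real genai-module implementation may differ; trying each `{` from the front and keeping the first span that actually parses sidesteps stray braces inside the `<think>` preamble:

```python
import json


def extract_json_block(text: str) -> dict:
    """Parse the outermost JSON object embedded in model output.

    Reasoning models emit prose (possibly with stray braces) before the
    JSON, so we try raw_decode at every '{' and return the first candidate
    that parses into an object. Starting from the front yields the
    outermost object rather than a nested one.
    """
    decoder = json.JSONDecoder()
    for start in range(len(text)):
        if text[start] != "{":
            continue
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue  # stray brace or prose fragment; keep scanning
        if isinstance(obj, dict):
            return obj
    raise ValueError("no JSON object found in model output")
```

Leaning on `json.JSONDecoder.raw_decode` avoids hand-rolling a brace/string-escape scanner, which is where ad-hoc extractors usually go wrong.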
Production notes:

- `IX_DEFAULT_MODEL=qwen3:14b` (already pulled on the host). The spec listed `gpt-oss:20b` as a concrete example; we swapped to keep the deploy on-prem without an extra `ollama pull`.
- Torch 2.11's default cu13 wheels fall back to CPU against the host's CUDA 12.4 driver, so Surya runs on CPU. Expected inference times: seconds per page. Upgrading the NVIDIA driver (or pinning a cu12-compatible torch wheel newer than 2.7) will unlock the GPU with no code changes.
## E2E smoke test (`scripts/e2e_smoke.py`)

What it does (from the Mac):

1. Checks `/healthz`.
2. Starts a tiny HTTP server on the Mac's LAN IP serving `tests/fixtures/synthetic_giro.pdf`.
3. Submits a `POST /jobs` with `use_case=bank_statement_header`, the fixture URL in `context.files`, and a Paperless-style OCR text in `context.texts` (to exercise the `text_agreement` cross-check).
4. Polls `GET /jobs/{id}` every 2 s until terminal or a 120 s timeout.
5. Asserts: `status == "done"`, `bank_name` non-empty, `provenance.fields["result.closing_balance"].provenance_verified == True`, `text_agreement == True`, total elapsed < 60 s.

A non-zero exit means the deploy is not healthy. Roll back via `git revert HEAD`.
## Operational checklists
|
||||
|
||||
### After `ollama pull` on the host
|
||||
|
||||
The `IX_DEFAULT_MODEL` env var on the server's `.env` must match something in `ollama list`. Changing the default means:
|
||||
|
||||
1. Edit `/home/server/Public/infoxtractor/app/.env` → `IX_DEFAULT_MODEL=<new>`.
|
||||
2. `docker compose --project-directory /home/server/Public/infoxtractor/app restart`.
|
||||
3. `curl http://192.168.68.42:8994/healthz` → confirm `ollama: ok`.
|
||||
|
||||
### If `/healthz` shows `ollama: degraded`
|
||||
|
||||
`qwen3:14b` (or the configured default) is not pulled. On the host:
|
||||
```bash
|
||||
ssh server@192.168.68.42 "docker exec ollama ollama pull qwen3:14b"
|
||||
```
|
||||
|
||||
### If `/healthz` shows `ocr: fail`
|
||||
|
||||
Surya couldn't initialize (model missing, CUDA unavailable, OOM). First run can be slow — models download on first call. Check container logs:
|
||||
```bash
|
||||
ssh server@192.168.68.42 "docker logs infoxtractor --tail 200"
|
||||
```
|
||||
|
||||
### If the container fails to start
|
||||
|
||||
```bash
|
||||
ssh server@192.168.68.42 "tail -100 /tmp/infoxtractor-deploy.log"
|
||||
ssh server@192.168.68.42 "docker compose -f /home/server/Public/infoxtractor/app/docker-compose.yml logs --tail 200"
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
- Monitoring dashboard auto-discovers via the `infrastructure.web_url` label on the container: `http://192.168.68.42:8001` → "infoxtractor" card.
|
||||
- Backup opt-in via `backup.enable=true` + `backup.type=postgres` + `backup.name=infoxtractor` labels. The daily backup script picks up the `infoxtractor` Postgres database automatically.
|
||||
|
||||
## Ports
|
||||
|
||||
| Port | Direction | Source | Service |
|
||||
|------|-----------|--------|---------|
|
||||
| 8994/tcp | ALLOW | 192.168.68.0/24 | ix REST + healthz (LAN only; not publicly exposed) |
|
||||
|
||||
No VPS Caddy entry; no `infrastructure.docs_url` label — this is an internal service.
|
||||
|
@@ -85,6 +85,7 @@ class FileRef(BaseModel):
    url: str  # http(s):// or file://
    headers: dict[str, str] = {}  # e.g. {"Authorization": "Token …"}
    max_bytes: Optional[int] = None  # per-file override; defaults to IX_FILE_MAX_BYTES
    display_name: Optional[str] = None  # UI-only metadata; client-provided filename for display (pipeline ignores)

class Options(BaseModel):
    ocr: OCROptions = OCROptions()

@@ -108,6 +109,25 @@ class ProvenanceOptions(BaseModel):

**Dropped from spec (no-ops under MVP):** `OCROptions.computer_vision_scaling_factor`, `include_page_tags` (always on), `GenAIOptions.use_vision`/`vision_scaling_factor`/`vision_detail`/`reasoning_effort`, `ProvenanceOptions.granularity`/`include_bounding_boxes`/`source_type`/`min_confidence`, `RequestIX.version`.

**Ad-hoc use cases (post-MVP add-on).** `RequestIX` carries an optional `use_case_inline: InlineUseCase | None = None`. When set, the pipeline builds the `(Request, Response)` Pydantic class pair on the fly from that inline definition and **skips the registry lookup entirely** — the `use_case` field becomes a free-form label (still required for metrics / logging). Inline definitions look like:

```python
class UseCaseFieldDef(BaseModel):
    name: str  # valid Python identifier
    type: Literal["str", "int", "float", "decimal", "date", "datetime", "bool"]
    required: bool = False
    description: str | None = None
    choices: list[str] | None = None  # str-typed fields only; builds Literal[*choices]

class InlineUseCase(BaseModel):
    use_case_name: str
    system_prompt: str
    default_model: str | None = None
    fields: list[UseCaseFieldDef]
```

Precedence: `use_case_inline` wins when both are set. Structural errors (dup field name, invalid identifier, `choices` on a non-str type, empty fields list) raise `IX_001_001` (same code as registry miss). The builder lives in `ix.use_cases.inline.build_use_case_classes` and returns fresh classes per call — the pipeline never caches them.
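The structural checks listed above can be approximated in plain Python. This is a hypothetical helper for illustration only — the real builder operates on the Pydantic `InlineUseCase` model and raises the `IX_001_001` error type:

```python
def validate_inline_use_case(fields: list[dict]) -> list[str]:
    """Collect the structural errors the text says map to IX_001_001.

    Assumed field shape: {"name": ..., "type": ..., "choices": ...};
    returns human-readable error strings instead of raising.
    """
    errors: list[str] = []
    if not fields:
        errors.append("empty fields list")
    seen: set[str] = set()
    for f in fields:
        name = f.get("name", "")
        if not name.isidentifier():
            errors.append(f"invalid identifier: {name!r}")
        if name in seen:
            errors.append(f"duplicate field name: {name!r}")
        seen.add(name)
        if f.get("choices") is not None and f.get("type") != "str":
            errors.append(f"choices on non-str field: {name!r}")
    return errors
```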

### ResponseIX

Identical to spec §2.2 except `FieldProvenance` gains two fields:

@@ -206,14 +226,15 @@ Callers that prefer direct SQL (the `pg_queue_adapter` contract): insert a row w
| `POST` | `/jobs` | Body = `RequestIX` (+ optional `callback_url`). → `201 {job_id, ix_id, status: "pending"}`. Idempotent on `(ix_client_id, request_id)` — same pair returns the existing `job_id` with `200`. |
| `GET` | `/jobs/{job_id}` | → full `Job`. Source of truth regardless of submission path or callback outcome. |
| `GET` | `/jobs?client_id=…&request_id=…` | Lookup-by-correlation (caller idempotency helper). The pair is UNIQUE in the table → at most one match. Returns the job or `404`. |
| `GET` | `/healthz` | `{postgres, ollama, ocr}`. See below for semantics. Used by `infrastructure` monitoring dashboard. |
| `GET` | `/healthz` | `{postgres, ollama, ocr, ocr_gpu}`. See below for semantics. Used by `infrastructure` monitoring dashboard. `ocr_gpu` is additive metadata (not part of the gate). |
| `GET` | `/metrics` | Counters over the last 24 hours: `jobs_pending`, `jobs_running`, `jobs_done_24h`, `jobs_error_24h`, per-use-case avg seconds over the same window. Plain JSON, no Prometheus format for MVP. |
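The 24-hour windowing behind `/metrics` can be sketched in pure Python. The row shape here (`status`, `finished_at`, `use_case`, `seconds`) is an assumption for illustration; the service computes the same counters with SQL aggregates:

```python
from datetime import datetime, timedelta

def metrics_24h(jobs: list[dict], now: datetime) -> dict:
    """Aggregate the /metrics counters from in-memory job rows."""
    cutoff = now - timedelta(hours=24)
    # Terminal rows only count inside the trailing 24 h window.
    done = [j for j in jobs if j["status"] == "done" and j["finished_at"] >= cutoff]
    errors = [j for j in jobs if j["status"] == "error" and j["finished_at"] >= cutoff]
    per_uc: dict[str, list[float]] = {}
    for j in done:
        per_uc.setdefault(j["use_case"], []).append(j["seconds"])
    return {
        "jobs_pending": sum(j["status"] == "pending" for j in jobs),
        "jobs_running": sum(j["status"] == "running" for j in jobs),
        "jobs_done_24h": len(done),
        "jobs_error_24h": len(errors),
        "avg_seconds_by_use_case": {uc: sum(v) / len(v) for uc, v in per_uc.items()},
    }
```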

**`/healthz` semantics:**
- `postgres`: `SELECT 1` on the job store pool; `ok` iff the query returns within 2 s.
- `ollama`: `GET {IX_OLLAMA_URL}/api/tags` within 5 s; `ok` iff reachable AND the default model (`IX_DEFAULT_MODEL`) is listed in the tags response; `degraded` iff reachable but the model is missing (ops action: run `ollama pull <model>` on the host); `fail` on any other error.
- `ocr`: `SuryaOCRClient.selfcheck()` — returns `ok` iff CUDA is available and the Surya text-recognition model is loaded into GPU memory at process start. `fail` on any error.
- Overall HTTP status: `200` iff all three are `ok`; `503` otherwise. The monitoring dashboard only surfaces `200`/`non-200`.
- `ocr_gpu`: `true | false | null`. Additive metadata: reports whether the OCR client observed `torch.cuda.is_available() == True` at first warm-up. `null` means not yet probed (fresh process, fake client, etc.). The UI reads this to surface a CPU-mode slowdown notice; never part of the 200/503 gate.
- Overall HTTP status: `200` iff all three core statuses (`postgres`, `ollama`, `ocr`) are `ok`; `503` otherwise. `ocr_gpu` does not affect the gate. The monitoring dashboard only surfaces `200`/`non-200`.

**Callback delivery** (when `callback_url` is set): one POST of the full `Job` body, 10 s timeout. 2xx → `callback_status='delivered'`. Anything else → `'failed'`. No retry. Callers always have `GET /jobs/{id}` as the authoritative fallback.

@@ -1,6 +1,8 @@
[project]
name = "infoxtractor"
version = "0.1.0"
# Released 2026-04-18 with the first live deploy of the MVP. See
# docs/deployment.md §"First deploy" for the commit + /healthz times.
description = "Async on-prem LLM-powered structured information extraction microservice"
readme = "README.md"
requires-python = ">=3.12"
@@ -27,14 +29,24 @@ dependencies = [
    "pillow>=10.2,<11.0",
    "python-magic>=0.4.27",
    "python-dateutil>=2.9",

    # UI (HTMX + Jinja2 templates served from /ui). Both arrive as transitive
    # deps via FastAPI/Starlette already, but we pin explicitly so the import
    # surface is owned by us. python-multipart backs FastAPI's `Form()` /
    # `UploadFile` parsing — required by `/ui/jobs` submissions.
    "jinja2>=3.1",
    "aiofiles>=24.1",
    "python-multipart>=0.0.12",
]

[project.optional-dependencies]
ocr = [
    # Real OCR engine — pulls torch + CUDA wheels. Kept optional so CI
    # (no GPU) can install the base package without the model deps.
    "surya-ocr>=0.9",
    "torch>=2.4",
    # Real OCR engine. Kept optional so CI (no GPU) can install the base
    # package without the model deps.
    # surya >= 0.17 is required: the client code uses the
    # `surya.foundation` module, which older releases don't expose.
    "surya-ocr>=0.17,<0.18",
    "torch>=2.7",
]
dev = [
    "pytest>=8.3",

@@ -44,6 +56,11 @@ dev = [
    "mypy>=1.13",
]

# Note: the default pypi torch ships cu13 wheels, which emit a
# UserWarning and fall back to CPU against the deploy host's CUDA 12.4
# driver. Surya then runs on CPU — slower but correct for MVP. A future
# driver upgrade unlocks GPU Surya with no code changes.

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

66 scripts/create_fixture_pdf.py Normal file
@@ -0,0 +1,66 @@
"""Build the synthetic E2E fixture PDF at ``tests/fixtures/synthetic_giro.pdf``.

Re-runnable on demand. Output bytes are stable across runs in page
content, layout, and text — only the PDF's embedded timestamps change,
which pipeline tests don't read. The committed fixture is what CI
consumes; re-run this script locally if you change the ground truth.

Contents: one A4 portrait page with six known strings placed at fixed
positions near the top. The goal is reproducible ground truth, not a
realistic bank statement. The pipeline's fake OCR client is seeded with
those same strings (at plausible bboxes) so the E2E test can assert
exact matches.

Usage::

    uv run python scripts/create_fixture_pdf.py
"""

from __future__ import annotations

from pathlib import Path

import fitz  # PyMuPDF

OUT_PATH = (
    Path(__file__).resolve().parent.parent / "tests" / "fixtures" / "synthetic_giro.pdf"
)

LINES: list[str] = [
    "DKB",
    "IBAN: DE89370400440532013000",
    "Statement period: 01.03.2026 - 31.03.2026",
    "Opening balance: 1234.56 EUR",
    "Closing balance: 1450.22 EUR",
    "Statement date: 31.03.2026",
]


def build() -> None:
    doc = fitz.open()
    # A4 @ 72 dpi -> 595 x 842 points.
    page = doc.new_page(width=595, height=842)
    y = 72.0
    for line in LINES:
        page.insert_text(
            (72.0, y),
            line,
            fontsize=12,
            fontname="helv",
        )
        y += 24.0
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    # deflate=False + garbage=0 keeps the output byte-stable.
    doc.save(
        str(OUT_PATH),
        deflate=False,
        deflate_images=False,
        garbage=0,
        clean=False,
    )
    doc.close()


if __name__ == "__main__":
    build()
    print(f"wrote {OUT_PATH}")
210 scripts/e2e_smoke.py Executable file
@@ -0,0 +1,210 @@
"""End-to-end smoke test against the deployed infoxtractor service.

Uploads a synthetic bank-statement fixture, polls for completion, and asserts
the provenance flags per spec §12 E2E. Intended to run from the Mac after
every `git push server main` as the deploy gate.

Prerequisites:
- The service is running and reachable at --base-url (default
  http://192.168.68.42:8994).
- The fixture `tests/fixtures/synthetic_giro.pdf` is present.
- The Mac and the server are on the same LAN (the server must be able to
  reach the Mac to download the fixture).

Exit codes:
    0  all assertions passed within the timeout
    1  at least one assertion failed
    2  the job never reached a terminal state in time
    3  the service was unreachable or returned an unexpected error

Usage:
    python scripts/e2e_smoke.py
    python scripts/e2e_smoke.py --base-url http://localhost:8994
"""

from __future__ import annotations

import argparse
import http.server
import json
import socket
import socketserver
import sys
import threading
import time
import urllib.error
import urllib.request
import uuid
from pathlib import Path

DEFAULT_BASE_URL = "http://192.168.68.42:8994"
FIXTURE = Path(__file__).parent.parent / "tests" / "fixtures" / "synthetic_giro.pdf"
TIMEOUT_SECONDS = 120
POLL_INTERVAL_SECONDS = 2


def find_lan_ip() -> str:
    """Return the Mac's LAN IP that the server can reach."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # 192.168.68.42 is the server; getting the default route towards it
        # yields the NIC with the matching subnet.
        s.connect(("192.168.68.42", 80))
        return s.getsockname()[0]
    finally:
        s.close()


def serve_fixture_in_background(fixture: Path) -> tuple[str, threading.Event]:
    """Serve the fixture on a temporary HTTP server; return the URL and a stop event."""
    if not fixture.exists():
        print(f"FIXTURE MISSING: {fixture}", file=sys.stderr)
        sys.exit(3)

    directory = fixture.parent
    filename = fixture.name
    lan_ip = find_lan_ip()

    class Handler(http.server.SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, directory=str(directory), **kwargs)

        def log_message(self, format: str, *args) -> None:  # quiet
            pass

    # Pick any free port.
    httpd = socketserver.TCPServer((lan_ip, 0), Handler)
    port = httpd.server_address[1]
    url = f"http://{lan_ip}:{port}/{filename}"
    stop = threading.Event()

    def _serve():
        try:
            while not stop.is_set():
                httpd.handle_request()
        finally:
            httpd.server_close()

    # Run in a thread. Use a loose timeout so handle_request returns when stop is set.
    httpd.timeout = 0.5
    t = threading.Thread(target=_serve, daemon=True)
    t.start()
    return url, stop


def post_job(base_url: str, file_url: str, client_id: str, request_id: str) -> dict:
    # Include a Paperless-style OCR of the fixture as context.texts so the
    # text_agreement cross-check has something to compare against.
    paperless_text = (
        "DKB\n"
        "DE89370400440532013000\n"
        "Statement period: 01.03.2026 - 31.03.2026\n"
        "Opening balance: 1234.56 EUR\n"
        "Closing balance: 1450.22 EUR\n"
        "31.03.2026\n"
    )
    payload = {
        "use_case": "bank_statement_header",
        "ix_client_id": client_id,
        "request_id": request_id,
        "context": {
            "files": [file_url],
            "texts": [paperless_text],
        },
    }
    req = urllib.request.Request(
        f"{base_url}/jobs",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


def get_job(base_url: str, job_id: str) -> dict:
    req = urllib.request.Request(f"{base_url}/jobs/{job_id}")
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout", type=int, default=TIMEOUT_SECONDS)
    args = parser.parse_args()

    # Sanity-check the service is up.
    try:
        with urllib.request.urlopen(f"{args.base_url}/healthz", timeout=5) as resp:
            health = json.loads(resp.read().decode("utf-8"))
        print(f"healthz: {health}")
    except urllib.error.URLError as e:
        print(f"service unreachable: {e}", file=sys.stderr)
        return 3

    fixture_url, stop_server = serve_fixture_in_background(FIXTURE)
    print(f"serving fixture at {fixture_url}")

    try:
        client_id = "e2e_smoke"
        request_id = f"smoke-{uuid.uuid4().hex[:8]}"
        submit = post_job(args.base_url, fixture_url, client_id, request_id)
        job_id = submit["job_id"]
        print(f"submitted job_id={job_id}")

        started = time.monotonic()
        last_status = None
        job = None
        while time.monotonic() - started < args.timeout:
            job = get_job(args.base_url, job_id)
            if job["status"] != last_status:
                print(f"[{time.monotonic() - started:5.1f}s] status={job['status']}")
                last_status = job["status"]
            if job["status"] in ("done", "error"):
                break
            time.sleep(POLL_INTERVAL_SECONDS)
        else:
            print(f"FAIL: timed out after {args.timeout}s", file=sys.stderr)
            return 2

        assert job is not None
        failed = []

        if job["status"] != "done":
            failed.append(f"status={job['status']!r} (want 'done')")

        response = job.get("response") or {}
        if response.get("error"):
            failed.append(f"response.error={response['error']!r}")

        result = (response.get("ix_result") or {}).get("result") or {}
        bank = result.get("bank_name")
        if not isinstance(bank, str) or not bank.strip():
            failed.append(f"bank_name={bank!r} (want non-empty string)")

        fields = (response.get("provenance") or {}).get("fields") or {}
        closing = fields.get("result.closing_balance") or {}
        if not closing.get("provenance_verified"):
            failed.append(f"closing_balance.provenance_verified={closing.get('provenance_verified')!r}")
        if closing.get("text_agreement") is not True:
            failed.append(f"closing_balance.text_agreement={closing.get('text_agreement')!r} (Paperless-style text submitted)")

        elapsed = time.monotonic() - started
        if elapsed >= 60:
            failed.append(f"elapsed={elapsed:.1f}s (≥ 60s; slow path)")

        print(json.dumps(result, indent=2, default=str))

        if failed:
            print("\n".join(f"FAIL: {f}" for f in failed), file=sys.stderr)
            return 1

        print(f"\nPASS in {elapsed:.1f}s")
        return 0
    finally:
        stop_server.set()


if __name__ == "__main__":
    sys.exit(main())
127 scripts/setup_server.sh Executable file
@@ -0,0 +1,127 @@
#!/usr/bin/env bash
# One-shot server setup for InfoXtractor. Idempotent: safe to re-run.
#
# Run from the Mac:
#   IX_POSTGRES_PASSWORD=<pw> ./scripts/setup_server.sh
#
# What it does on 192.168.68.42:
#   1. Creates the bare git repo `/home/server/Public/infoxtractor/repos.git` if missing.
#   2. Writes the post-receive hook (or updates it) and makes it executable.
#   3. Creates the Postgres role + database on the shared `postgis` container.
#   4. Writes `/home/server/Public/infoxtractor/app/.env` (0600) from .env.example.
#   5. Verifies `qwen3:14b` is pulled in Ollama.

set -euo pipefail

SERVER="${IX_SERVER:-server@192.168.68.42}"
APP_BASE="/home/server/Public/infoxtractor"
REPOS_GIT="${APP_BASE}/repos.git"
APP_DIR="${APP_BASE}/app"
DB_NAME="infoxtractor"
DB_USER="infoxtractor"

if [ -z "${IX_POSTGRES_PASSWORD:-}" ]; then
  read -r -s -p "Postgres password for role '${DB_USER}': " IX_POSTGRES_PASSWORD
  echo
fi

if [ -z "${IX_POSTGRES_PASSWORD}" ]; then
  echo "IX_POSTGRES_PASSWORD is required." >&2
  exit 1
fi

echo "==> 1/5 Ensuring bare repo + post-receive hook on ${SERVER}"
ssh "${SERVER}" bash -s <<EOF
set -euo pipefail
mkdir -p "${REPOS_GIT}" "${APP_DIR}"
if [ ! -f "${REPOS_GIT}/HEAD" ]; then
  git init --bare "${REPOS_GIT}"
fi

cat >"${REPOS_GIT}/hooks/post-receive" <<'HOOK'
#!/usr/bin/env bash
set -eo pipefail

APP_DIR="${APP_DIR}"
LOG="/tmp/infoxtractor-deploy.log"

echo "[\$(date -u '+%FT%TZ')] post-receive start" >> "\$LOG"

mkdir -p "\$APP_DIR"
GIT_WORK_TREE="\$APP_DIR" git --git-dir="${REPOS_GIT}" checkout -f main >> "\$LOG" 2>&1

cd "\$APP_DIR"
docker compose up -d --build >> "\$LOG" 2>&1

# Deploy gate: /healthz must return 200 within 60 s.
for i in \$(seq 1 30); do
  if curl -fsS http://localhost:8994/healthz > /dev/null 2>&1; then
    echo "[\$(date -u '+%FT%TZ')] healthz OK" >> "\$LOG"
    exit 0
  fi
  sleep 2
done

echo "[\$(date -u '+%FT%TZ')] healthz never reached OK" >> "\$LOG"
docker compose logs --tail 100 >> "\$LOG" 2>&1 || true
exit 1
HOOK

chmod +x "${REPOS_GIT}/hooks/post-receive"
EOF

echo "==> 2/5 Verifying Ollama has qwen3:14b pulled"
if ! ssh "${SERVER}" "docker exec ollama ollama list | awk '{print \$1}' | grep -qx 'qwen3:14b'"; then
  echo "FAIL: qwen3:14b not found in Ollama. Run: ssh ${SERVER} 'docker exec ollama ollama pull qwen3:14b'" >&2
  exit 1
fi

echo "==> 3/5 Creating Postgres role '${DB_USER}' and database '${DB_NAME}' on postgis container"
# Idempotent via DO blocks; uses docker exec to avoid needing psql on the host.
ssh "${SERVER}" bash -s <<EOF
set -euo pipefail
docker exec -i postgis psql -U postgres <<SQL
DO \\\$\\\$
BEGIN
  IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = '${DB_USER}') THEN
    CREATE ROLE ${DB_USER} LOGIN PASSWORD '${IX_POSTGRES_PASSWORD}';
  ELSE
    ALTER ROLE ${DB_USER} WITH PASSWORD '${IX_POSTGRES_PASSWORD}';
  END IF;
END
\\\$\\\$;
SQL

if ! docker exec -i postgis psql -U postgres -tc "SELECT 1 FROM pg_database WHERE datname = '${DB_NAME}'" | grep -q 1; then
  docker exec -i postgis createdb -U postgres -O ${DB_USER} ${DB_NAME}
fi
EOF

echo "==> 4/5 Writing ${APP_DIR}/.env on the server"
# Render .env from the repo's .env.example, substituting the password placeholder.
LOCAL_ENV_CONTENT="$(
  sed "s#<password>#${IX_POSTGRES_PASSWORD}#g" \
    "$(dirname "$0")/../.env.example"
)"
# Append IX_TEST_MODE=production for safety (fake mode stays off).
# .env is written in one pass and permissioned 0600.
ssh "${SERVER}" "install -d -m 0755 '${APP_DIR}' && cat > '${APP_DIR}/.env' <<'ENVEOF'
${LOCAL_ENV_CONTENT}
ENVEOF
chmod 0600 '${APP_DIR}/.env'"

echo "==> 5/5 Checking UFW rule for port 8994 (LAN only)"
ssh "${SERVER}" "sudo ufw status numbered | grep -F 8994" >/dev/null 2>&1 || {
  echo "NOTE: UFW doesn't yet allow 8994. Run on the server:"
  echo "  sudo ufw allow from 192.168.68.0/24 to any port 8994 proto tcp"
}

echo
echo "Done."
echo
echo "Next steps (on the Mac):"
echo "  git remote add server ssh://server@192.168.68.42${REPOS_GIT}"
echo "  git push server main"
echo "  ssh ${SERVER} 'tail -f /tmp/infoxtractor-deploy.log'"
echo "  curl http://192.168.68.42:8994/healthz"
echo "  python scripts/e2e_smoke.py"
6 src/ix/adapters/__init__.py Normal file
@@ -0,0 +1,6 @@
"""Transport adapters — REST (always on) + pg_queue (optional).

Adapters are thin: they marshal external events into :class:`RequestIX`
payloads that land in ``ix_jobs`` as pending rows, and they read back from
the same store. They do NOT run the pipeline themselves; the worker does.
"""
15 src/ix/adapters/pg_queue/__init__.py Normal file
@@ -0,0 +1,15 @@
"""Postgres queue adapter — ``LISTEN ix_jobs_new`` + 10 s fallback poll.

This is a secondary transport: a direct-SQL writer can insert a row and
``NOTIFY ix_jobs_new, '<job_id>'``, and the worker wakes up within one
roundtrip rather than after the 10 s fallback poll. The REST adapter
doesn't need the listener because the worker is already running
in-process; this exists for external callers who bypass the REST API.
"""

from ix.adapters.pg_queue.listener import (
    PgQueueListener,
    asyncpg_dsn_from_sqlalchemy_url,
)

__all__ = ["PgQueueListener", "asyncpg_dsn_from_sqlalchemy_url"]
111 src/ix/adapters/pg_queue/listener.py Normal file
@@ -0,0 +1,111 @@
"""Dedicated asyncpg connection that LISTENs to ``ix_jobs_new``.

We hold the connection *outside* the SQLAlchemy pool because SQLAlchemy's
asyncpg dialect doesn't expose the raw connection in a way that survives
the pool's checkout/checkin dance, and LISTEN needs a connection that
stays open for the full session to receive asynchronous notifications.

The adapter contract the worker sees is a single coroutine,
``wait_for_work(poll_seconds)``, which completes either when a NOTIFY
arrives or when ``poll_seconds`` elapse. The worker doesn't care which
woke it — it just goes back to its claim query.
"""

from __future__ import annotations

import asyncio
from urllib.parse import unquote, urlparse

import asyncpg


def asyncpg_dsn_from_sqlalchemy_url(url: str) -> str:
    """Strip the SQLAlchemy ``postgresql+asyncpg://`` prefix for raw asyncpg.

    asyncpg's connect() expects the plain ``postgres://user:pass@host/db``
    shape; the ``+driver`` segment SQLAlchemy adds confuses it. We also
    percent-decode the password — asyncpg accepts the raw form but not the
    pre-encoded ``%21`` passwords we sometimes use in dev.
    """

    parsed = urlparse(url)
    scheme = parsed.scheme.split("+", 1)[0]
    user = unquote(parsed.username) if parsed.username else ""
    password = unquote(parsed.password) if parsed.password else ""
    auth = ""
    if user:
        auth = f"{user}"
        if password:
            auth += f":{password}"
        auth += "@"
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc += f":{parsed.port}"
    return f"{scheme}://{auth}{netloc}{parsed.path}"


class PgQueueListener:
    """Long-lived asyncpg connection that sets an event on each NOTIFY.

    The worker uses :meth:`wait_for_work` as its ``wait_for_work`` hook:
    one call resolves when either a NOTIFY is received OR ``timeout``
    seconds elapse, whichever comes first. The event is cleared after each
    resolution so subsequent waits don't see stale state.
    """

    def __init__(self, dsn: str, channel: str = "ix_jobs_new") -> None:
        self._dsn = dsn
        self._channel = channel
        self._conn: asyncpg.Connection | None = None
        self._event = asyncio.Event()
        # Protect add_listener / remove_listener against concurrent
        # start/stop — shouldn't happen in practice, but a stray double-stop
        # from a lifespan shutdown shouldn't raise ``listener not found``.
        self._lock = asyncio.Lock()

    async def start(self) -> None:
        async with self._lock:
            if self._conn is not None:
                return
            self._conn = await asyncpg.connect(self._dsn)
            await self._conn.add_listener(self._channel, self._on_notify)

    async def stop(self) -> None:
        async with self._lock:
            if self._conn is None:
                return
            try:
                await self._conn.remove_listener(self._channel, self._on_notify)
            finally:
                await self._conn.close()
                self._conn = None

    def _on_notify(
        self,
        connection: asyncpg.Connection,
        pid: int,
        channel: str,
        payload: str,
    ) -> None:
        """asyncpg listener callback — signals the waiter."""

        # We don't care about payload/pid/channel — any NOTIFY on our
        # channel means "go check for pending rows". Keep the body tiny so
        # asyncpg's single dispatch loop stays snappy.
        self._event.set()

    async def wait_for_work(self, timeout: float) -> None:
        """Resolve when a NOTIFY arrives or ``timeout`` seconds pass.

        We wait on the event with a timeout. ``asyncio.wait_for`` raises
        :class:`asyncio.TimeoutError` on expiry; we swallow it because the
        worker treats "either signal" identically. The event is cleared
        after every wait so the next call starts fresh.
        """

        try:
            await asyncio.wait_for(self._event.wait(), timeout=timeout)
        except TimeoutError:
            pass
        finally:
            self._event.clear()
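The "either signal" contract can be demonstrated in isolation with a plain `asyncio.Event` — a self-contained sketch of the same pattern, not the class above:

```python
import asyncio

async def demo() -> list[str]:
    """A waiter resolves on NOTIFY (modelled by Event.set()) or on timeout.

    The caller can't tell which woke it — it just re-runs its claim query.
    """
    event = asyncio.Event()
    log: list[str] = []

    async def wait_for_work(timeout: float) -> None:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:  # alias of TimeoutError on 3.11+
            pass
        finally:
            event.clear()

    # 1) No NOTIFY: the wait resolves via the timeout path.
    await wait_for_work(0.05)
    log.append("woke (timeout)")
    # 2) NOTIFY arrives mid-wait: the wait resolves early.
    asyncio.get_running_loop().call_later(0.01, event.set)
    await wait_for_work(5.0)
    log.append("woke (notify)")
    return log
```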

5 src/ix/adapters/rest/__init__.py Normal file
@@ -0,0 +1,5 @@
"""FastAPI REST adapter — POST /jobs / GET /jobs / GET /healthz / GET /metrics.

Routes are defined in ``routes.py``. The ``create_app`` factory in
``ix.app`` wires them up alongside the worker lifespan.
"""
240 src/ix/adapters/rest/routes.py Normal file
@@ -0,0 +1,240 @@
"""REST routes (spec §5).

The routes depend on two injected objects:

* a session factory (``get_session_factory_dep``): swapped in tests so we can
  use the fixture's per-test engine instead of the lazy process-wide one in
  ``ix.store.engine``.
* a :class:`Probes` bundle (``get_probes``): each probe returns the
  per-subsystem state string used by ``/healthz``. Tests inject canned
  probes; Chunk 4 wires the real Ollama/Surya ones.

``/healthz`` has a strict 2-second postgres timeout — we use an
``asyncio.wait_for`` around a ``SELECT 1`` roundtrip so a broken pool or a
hung connection can't wedge the healthcheck endpoint.
"""

from __future__ import annotations

import asyncio
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Annotated, Literal
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from ix.adapters.rest.schemas import HealthStatus, JobSubmitResponse, MetricsResponse
from ix.contracts.job import Job
from ix.contracts.request import RequestIX
from ix.store import jobs_repo
from ix.store.engine import get_session_factory
from ix.store.models import IxJob


@dataclass
class Probes:
    """Injected subsystem-probe callables for ``/healthz``.

    Each callable returns the literal status string expected in the body.
    Probes are sync by design — none of the real ones need awaits today, and
    keeping them sync lets tests pass plain lambdas. Real probes that need
    async work run the call through ``asyncio.run_in_executor`` inside the
    callable (Chunk 4).

    ``ocr_gpu`` is additive metadata for the UI (not a health gate): returns
    ``True`` iff the OCR client reports CUDA is available, ``False`` for
    explicit CPU mode, ``None`` if unknown (fake client, not yet warmed up).
    """

    ollama: Callable[[], Literal["ok", "degraded", "fail"]]
    ocr: Callable[[], Literal["ok", "fail"]]
    ocr_gpu: Callable[[], bool | None] = field(default=lambda: None)


def get_session_factory_dep() -> async_sessionmaker[AsyncSession]:
    """Default DI: the process-wide store factory. Tests override this."""

    return get_session_factory()


def get_probes() -> Probes:
    """Default DI: a pair of ``fail`` probes.

    Production wiring (Chunk 4) overrides this factory with real Ollama +
    Surya probes at app-startup time. Integration tests override via
    ``app.dependency_overrides[get_probes]`` with a canned ``ok`` pair.
    The default here ensures a mis-wired deployment surfaces clearly in
    ``/healthz`` rather than claiming everything is fine by accident.
    """

    return Probes(ollama=lambda: "fail", ocr=lambda: "fail")


router = APIRouter()


@router.post("/jobs", response_model=JobSubmitResponse, status_code=201)
async def submit_job(
    request: RequestIX,
    response: Response,
    session_factory: Annotated[
        async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
    ],
) -> JobSubmitResponse:
    """Submit a new job.

    Per spec §5: 201 on first insert, 200 on idempotent re-submit of an
    existing ``(client_id, request_id)`` pair. We detect the second case by
    snapshotting the pre-insert row set and comparing ``created_at``.
    """

    async with session_factory() as session:
        existing = await jobs_repo.get_by_correlation(
            session, request.ix_client_id, request.request_id
        )
        job = await jobs_repo.insert_pending(
            session, request, callback_url=request.callback_url
        )
        await session.commit()

    if existing is not None:
        # Idempotent re-submit — flip to 200. FastAPI's declared status_code
        # is 201, but setting response.status_code overrides it per-call.
|
||||
response.status_code = 200
|
||||
|
||||
return JobSubmitResponse(job_id=job.job_id, ix_id=job.ix_id, status=job.status)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=Job)
|
||||
async def get_job(
|
||||
job_id: UUID,
|
||||
session_factory: Annotated[
|
||||
async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
|
||||
],
|
||||
) -> Job:
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.get(session, job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=Job)
|
||||
async def lookup_job_by_correlation(
|
||||
client_id: Annotated[str, Query(...)],
|
||||
request_id: Annotated[str, Query(...)],
|
||||
session_factory: Annotated[
|
||||
async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
|
||||
],
|
||||
) -> Job:
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.get_by_correlation(session, client_id, request_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/healthz")
|
||||
async def healthz(
|
||||
response: Response,
|
||||
session_factory: Annotated[
|
||||
async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
|
||||
],
|
||||
probes: Annotated[Probes, Depends(get_probes)],
|
||||
) -> HealthStatus:
|
||||
"""Per spec §5: postgres / ollama / ocr; 200 iff all three == ok."""
|
||||
|
||||
postgres_state: Literal["ok", "fail"] = "fail"
|
||||
try:
|
||||
async def _probe() -> None:
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("SELECT 1"))
|
||||
|
||||
await asyncio.wait_for(_probe(), timeout=2.0)
|
||||
postgres_state = "ok"
|
||||
except Exception:
|
||||
postgres_state = "fail"
|
||||
|
||||
try:
|
||||
ollama_state = probes.ollama()
|
||||
except Exception:
|
||||
ollama_state = "fail"
|
||||
try:
|
||||
ocr_state = probes.ocr()
|
||||
except Exception:
|
||||
ocr_state = "fail"
|
||||
|
||||
try:
|
||||
ocr_gpu_state: bool | None = probes.ocr_gpu()
|
||||
except Exception:
|
||||
ocr_gpu_state = None
|
||||
|
||||
body = HealthStatus(
|
||||
postgres=postgres_state,
|
||||
ollama=ollama_state,
|
||||
ocr=ocr_state,
|
||||
ocr_gpu=ocr_gpu_state,
|
||||
)
|
||||
if postgres_state != "ok" or ollama_state != "ok" or ocr_state != "ok":
|
||||
response.status_code = 503
|
||||
return body
|
||||
|
||||
|
||||
@router.get("/metrics", response_model=MetricsResponse)
|
||||
async def metrics(
|
||||
session_factory: Annotated[
|
||||
async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
|
||||
],
|
||||
) -> MetricsResponse:
|
||||
"""Counters over the last 24h — plain JSON per spec §5."""
|
||||
|
||||
since = datetime.now(UTC) - timedelta(hours=24)
|
||||
|
||||
async with session_factory() as session:
|
||||
pending = await session.scalar(
|
||||
select(func.count()).select_from(IxJob).where(IxJob.status == "pending")
|
||||
)
|
||||
running = await session.scalar(
|
||||
select(func.count()).select_from(IxJob).where(IxJob.status == "running")
|
||||
)
|
||||
done_24h = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(IxJob)
|
||||
.where(IxJob.status == "done", IxJob.finished_at >= since)
|
||||
)
|
||||
error_24h = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(IxJob)
|
||||
.where(IxJob.status == "error", IxJob.finished_at >= since)
|
||||
)
|
||||
|
||||
# Per-use-case average seconds. ``request`` is JSONB, so we dig out
|
||||
# the use_case key via ->>. Only consider rows that both started and
|
||||
# finished in the window (can't compute elapsed otherwise).
|
||||
rows = (
|
||||
await session.execute(
|
||||
text(
|
||||
"SELECT request->>'use_case' AS use_case, "
|
||||
"AVG(EXTRACT(EPOCH FROM (finished_at - started_at))) "
|
||||
"FROM ix_jobs "
|
||||
"WHERE status='done' AND finished_at IS NOT NULL "
|
||||
"AND started_at IS NOT NULL AND finished_at >= :since "
|
||||
"GROUP BY request->>'use_case'"
|
||||
),
|
||||
{"since": since},
|
||||
)
|
||||
).all()
|
||||
|
||||
by_use_case = {row[0]: float(row[1]) for row in rows if row[0] is not None}
|
||||
|
||||
return MetricsResponse(
|
||||
jobs_pending=int(pending or 0),
|
||||
jobs_running=int(running or 0),
|
||||
jobs_done_24h=int(done_24h or 0),
|
||||
jobs_error_24h=int(error_24h or 0),
|
||||
by_use_case_seconds=by_use_case,
|
||||
)
|
||||
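The 2-second postgres gate in ``healthz`` can be sketched in isolation. A minimal sketch, assuming nothing beyond the stdlib: ``slow_select`` is a hypothetical stand-in for a ``SELECT 1`` roundtrip on a wedged connection, not part of the codebase.

```python
import asyncio


async def slow_select() -> None:
    # Hypothetical stand-in for a SELECT 1 roundtrip that never returns
    # in time (broken pool / hung connection).
    await asyncio.sleep(10)


async def probe_postgres(timeout_s: float = 2.0) -> str:
    # Mirrors the /healthz pattern: any failure, including the timeout,
    # collapses to "fail" instead of hanging the endpoint.
    try:
        await asyncio.wait_for(slow_select(), timeout=timeout_s)
        return "ok"
    except Exception:
        return "fail"


state = asyncio.run(probe_postgres(timeout_s=0.05))
print(state)  # → fail
```

Because ``asyncio.wait_for`` cancels the inner coroutine on timeout, the session is abandoned rather than awaited forever, which is exactly why the endpoint stays responsive.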
src/ix/adapters/rest/schemas.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""REST-adapter request / response bodies.

Most payloads reuse the core contracts directly (:class:`RequestIX`,
:class:`Job`). The only adapter-specific shape is the lightweight POST /jobs
response (``job_id``, ``ix_id``, ``status``) — callers don't need the full Job
envelope back on submit; they poll ``GET /jobs/{id}`` for that.
"""

from __future__ import annotations

from typing import Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict


class JobSubmitResponse(BaseModel):
    """What POST /jobs returns: just enough to poll or subscribe to callbacks."""

    model_config = ConfigDict(extra="forbid")

    job_id: UUID
    ix_id: str
    status: Literal["pending", "running", "done", "error"]


class HealthStatus(BaseModel):
    """Body of GET /healthz.

    Each field reports per-subsystem state. Overall HTTP status is 200 iff
    each of the three core status keys is ``"ok"`` (spec §5). ``ollama`` can
    be ``"degraded"`` when the backend is reachable but the default model
    isn't pulled — monitoring surfaces that as non-200.

    ``ocr_gpu`` is additive metadata, not part of the health gate: it reports
    whether the Surya OCR client observed ``torch.cuda.is_available() == True``
    on first warm-up. ``None`` means we haven't probed yet (fresh process,
    fake client, or warm_up hasn't happened). The UI reads this to surface a
    CPU-mode slowdown warning to users.
    """

    model_config = ConfigDict(extra="forbid")

    postgres: Literal["ok", "fail"]
    ollama: Literal["ok", "degraded", "fail"]
    ocr: Literal["ok", "fail"]
    ocr_gpu: bool | None = None


class MetricsResponse(BaseModel):
    """Body of GET /metrics — plain JSON (no Prometheus format for MVP)."""

    model_config = ConfigDict(extra="forbid")

    jobs_pending: int
    jobs_running: int
    jobs_done_24h: int
    jobs_error_24h: int
    by_use_case_seconds: dict[str, float]
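The "200 iff every core key is ok" gate described in the ``HealthStatus`` docstring reduces to a tiny pure function. A sketch outside FastAPI, with status codes per spec §5 and ``ocr_gpu`` deliberately absent because it never gates:

```python
def healthz_http_status(postgres: str, ollama: str, ocr: str) -> int:
    # 200 only when all three core subsystems report "ok";
    # "degraded" ollama (model not pulled) still yields 503.
    return 200 if (postgres, ollama, ocr) == ("ok", "ok", "ok") else 503


print(healthz_http_status("ok", "ok", "ok"))        # → 200
print(healthz_http_status("ok", "degraded", "ok"))  # → 503
```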
src/ix/app.py (new file, 260 lines)
@@ -0,0 +1,260 @@
"""FastAPI app factory + lifespan.

``create_app()`` wires the REST router on top of a lifespan that spawns the
worker loop (Task 3.5) and the pg_queue listener (Task 3.6). Tests that
don't care about the worker call ``create_app(spawn_worker=False)`` so the
lifespan returns cleanly.

Task 4.3 fills in the production wiring:

* Factories (``make_genai_client`` / ``make_ocr_client``) pick between
  fakes (``IX_TEST_MODE=fake``) and real Ollama/Surya clients.
* ``/healthz`` probes call ``selfcheck()`` on the active clients. In
  ``fake`` mode they always report ok.
* The worker's :class:`Pipeline` is built once per spawn with the real
  chain of Steps; each call to the injected ``pipeline_factory`` returns
  a fresh Pipeline so per-request state stays isolated.
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager, suppress
from typing import Literal

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from ix.adapters.rest.routes import Probes, get_probes
from ix.adapters.rest.routes import router as rest_router
from ix.config import AppConfig, get_config
from ix.genai import make_genai_client
from ix.genai.client import GenAIClient
from ix.ocr import make_ocr_client
from ix.ocr.client import OCRClient
from ix.pipeline.genai_step import GenAIStep
from ix.pipeline.ocr_step import OCRStep
from ix.pipeline.pipeline import Pipeline
from ix.pipeline.reliability_step import ReliabilityStep
from ix.pipeline.response_handler_step import ResponseHandlerStep
from ix.pipeline.setup_step import SetupStep
from ix.ui import build_router as build_ui_router
from ix.ui.routes import STATIC_DIR as UI_STATIC_DIR


def build_pipeline(genai: GenAIClient, ocr: OCRClient, cfg: AppConfig) -> Pipeline:
    """Assemble the production :class:`Pipeline` with injected clients.

    Kept as a module-level helper so tests that want to exercise the
    production wiring (without running the worker) can call it directly.
    """

    from pathlib import Path

    from ix.ingestion import FetchConfig

    return Pipeline(
        steps=[
            SetupStep(
                tmp_dir=Path(cfg.tmp_dir),
                fetch_config=FetchConfig(
                    connect_timeout_s=float(cfg.file_connect_timeout_seconds),
                    read_timeout_s=float(cfg.file_read_timeout_seconds),
                    max_bytes=cfg.file_max_bytes,
                ),
            ),
            OCRStep(ocr_client=ocr),
            GenAIStep(genai_client=genai),
            ReliabilityStep(),
            ResponseHandlerStep(),
        ]
    )


def _make_ollama_probe(
    genai: GenAIClient, cfg: AppConfig
) -> Callable[[], Literal["ok", "degraded", "fail"]]:
    """Adapter: async ``selfcheck`` → sync callable the route expects.

    Always drives the coroutine on a throwaway event loop in a separate
    thread. This keeps the behavior identical whether the caller holds an
    event loop (FastAPI request) or doesn't (a CLI tool), and avoids the
    ``asyncio.run`` vs. already-running-loop footgun.
    """

    def probe() -> Literal["ok", "degraded", "fail"]:
        if not hasattr(genai, "selfcheck"):
            return "ok"  # fake client — nothing to probe.
        return _run_async_sync(
            lambda: genai.selfcheck(expected_model=cfg.default_model),  # type: ignore[attr-defined]
            fallback="fail",
        )

    return probe


def _make_ocr_probe(ocr: OCRClient) -> Callable[[], Literal["ok", "fail"]]:
    def probe() -> Literal["ok", "fail"]:
        if not hasattr(ocr, "selfcheck"):
            return "ok"  # fake — nothing to probe.
        return _run_async_sync(
            lambda: ocr.selfcheck(),  # type: ignore[attr-defined]
            fallback="fail",
        )

    return probe


def _make_ocr_gpu_probe(ocr: OCRClient) -> Callable[[], bool | None]:
    """Adapter: read the OCR client's recorded ``gpu_available`` attribute.

    The attribute is set by :meth:`SuryaOCRClient.warm_up` on first load.
    Returns ``None`` when the client has no such attribute (e.g. FakeOCRClient
    in test mode) or warm_up hasn't happened yet. Never raises.
    """

    def probe() -> bool | None:
        return getattr(ocr, "gpu_available", None)

    return probe


def _run_async_sync(make_coro, *, fallback: str) -> str:  # type: ignore[no-untyped-def]
    """Run ``make_coro()`` on a fresh loop in a thread; return its result.

    The thread owns its own event loop so the caller's loop (if any) keeps
    running. Any exception collapses to ``fallback``.
    """

    import threading

    result: dict[str, object] = {}

    def _runner() -> None:
        loop = asyncio.new_event_loop()
        try:
            result["value"] = loop.run_until_complete(make_coro())
        except Exception as exc:  # any error collapses to fallback
            result["error"] = exc
        finally:
            loop.close()

    t = threading.Thread(target=_runner)
    t.start()
    t.join()
    if "error" in result or "value" not in result:
        return fallback
    return str(result["value"])


def create_app(*, spawn_worker: bool = True) -> FastAPI:
    """Construct the ASGI app.

    Parameters
    ----------
    spawn_worker:
        When True (default), the lifespan spawns the background worker task
        and the pg_queue listener. Integration tests that only exercise the
        REST adapter pass False so jobs pile up as ``pending`` and the tests
        can assert on their state without a racing worker mutating them.
    """

    @asynccontextmanager
    async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
        cfg = get_config()

        # Build the clients once per process. The worker's pipeline
        # factory closes over these so every job runs through the same
        # Ollama/Surya instance (Surya's predictors are heavy; re-loading
        # them per job would be catastrophic).
        genai_client = make_genai_client(cfg)
        ocr_client = make_ocr_client(cfg)

        # Override the route-level probe DI so /healthz reflects the
        # actual clients. Tests that want canned probes can still override
        # ``get_probes`` at the TestClient layer.
        _app.dependency_overrides.setdefault(
            get_probes,
            lambda: Probes(
                ollama=_make_ollama_probe(genai_client, cfg),
                ocr=_make_ocr_probe(ocr_client),
                ocr_gpu=_make_ocr_gpu_probe(ocr_client),
            ),
        )

        worker_task = None
        listener = None
        if spawn_worker:
            try:
                from ix.adapters.pg_queue.listener import (
                    PgQueueListener,
                    asyncpg_dsn_from_sqlalchemy_url,
                )

                listener = PgQueueListener(
                    dsn=asyncpg_dsn_from_sqlalchemy_url(cfg.postgres_url)
                )
                await listener.start()
            except Exception:
                listener = None

            try:
                worker_task = await _spawn_production_worker(
                    cfg, genai_client, ocr_client, listener
                )
            except Exception:
                worker_task = None
        try:
            yield
        finally:
            if worker_task is not None:
                worker_task.cancel()
                with suppress(Exception):
                    await worker_task
            if listener is not None:
                with suppress(Exception):
                    await listener.stop()

    app = FastAPI(lifespan=lifespan, title="infoxtractor", version="0.1.0")
    app.include_router(rest_router)
    # Browser UI — additive, never touches the REST paths above.
    app.include_router(build_ui_router())
    # Static assets for the UI. CDN-only for MVP so the directory is
    # essentially empty, but the mount must exist so relative asset
    # URLs resolve cleanly.
    app.mount(
        "/ui/static",
        StaticFiles(directory=str(UI_STATIC_DIR)),
        name="ui-static",
    )
    return app


async def _spawn_production_worker(
    cfg: AppConfig,
    genai: GenAIClient,
    ocr: OCRClient,
    listener,  # type: ignore[no-untyped-def]
) -> asyncio.Task[None]:
    """Spawn the background worker with a production pipeline factory."""

    from ix.store.engine import get_session_factory
    from ix.worker.loop import Worker

    def pipeline_factory() -> Pipeline:
        return build_pipeline(genai, ocr, cfg)

    worker = Worker(
        session_factory=get_session_factory(),
        pipeline_factory=pipeline_factory,
        poll_interval_seconds=10.0,
        max_running_seconds=2 * cfg.pipeline_request_timeout_seconds,
        callback_timeout_seconds=cfg.callback_timeout_seconds,
        wait_for_work=listener.wait_for_work if listener is not None else None,
    )

    stop = asyncio.Event()
    return asyncio.create_task(worker.run(stop))
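The thread-per-call bridge behind ``_run_async_sync`` can be exercised standalone. A trimmed re-statement of the helper with two toy coroutines (``selfcheck_ok`` / ``selfcheck_boom`` are illustrative, not real client methods):

```python
import asyncio
import threading


def run_async_sync(make_coro, *, fallback: str) -> str:
    # Fresh event loop on a dedicated thread: safe to call whether or not
    # the caller's thread already runs a loop.
    result: dict[str, object] = {}

    def runner() -> None:
        loop = asyncio.new_event_loop()
        try:
            result["value"] = loop.run_until_complete(make_coro())
        except Exception as exc:  # any error collapses to the fallback
            result["error"] = exc
        finally:
            loop.close()

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if "error" in result or "value" not in result:
        return fallback
    return str(result["value"])


async def selfcheck_ok() -> str:
    return "ok"


async def selfcheck_boom() -> str:
    raise RuntimeError("backend down")


print(run_async_sync(selfcheck_ok, fallback="fail"))    # → ok
print(run_async_sync(selfcheck_boom, fallback="fail"))  # → fail
```

The join-before-return makes the bridge synchronous from the caller's point of view, which is what lets ``/healthz`` keep its probe callables as plain functions.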
src/ix/config.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""Application configuration — loaded from ``IX_*`` env vars via pydantic-settings.

Spec §9 lists every tunable. This module is the single read-point for them;
callers that need runtime config should go through :func:`get_config` rather
than ``os.environ``. The LRU cache makes the first call materialise + validate
the full config and every subsequent call return the same instance.

Cache-clearing is public (``get_config.cache_clear()``) because tests need to
re-read after ``monkeypatch.setenv``. Production code never clears the cache.
"""

from __future__ import annotations

from functools import lru_cache
from typing import Literal

from pydantic_settings import BaseSettings, SettingsConfigDict


class AppConfig(BaseSettings):
    """Typed view over the ``IX_*`` environment.

    Field names drop the ``IX_`` prefix — pydantic-settings puts it back via
    ``env_prefix``. Defaults match the spec exactly; do not change a default
    here without updating spec §9 in the same commit.
    """

    model_config = SettingsConfigDict(
        env_prefix="IX_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # --- Job store ---
    # Defaults assume the ix container runs with `network_mode: host` and
    # reaches the shared `postgis` and `ollama` containers on loopback;
    # spec §11 / docker-compose.yml ship that configuration.
    postgres_url: str = (
        "postgresql+asyncpg://infoxtractor:<password>"
        "@127.0.0.1:5431/infoxtractor"
    )

    # --- LLM backend ---
    ollama_url: str = "http://127.0.0.1:11434"
    default_model: str = "qwen3:14b"

    # --- OCR ---
    ocr_engine: str = "surya"

    # --- Pipeline behavior ---
    pipeline_worker_concurrency: int = 1
    pipeline_request_timeout_seconds: int = 2700
    genai_call_timeout_seconds: int = 1500
    render_max_pixels_per_page: int = 75_000_000

    # --- File fetching ---
    tmp_dir: str = "/tmp/ix"
    file_max_bytes: int = 52_428_800
    file_connect_timeout_seconds: int = 10
    file_read_timeout_seconds: int = 30

    # --- Transport / callbacks ---
    callback_timeout_seconds: int = 10

    # --- Observability ---
    log_level: str = "INFO"

    # --- Test / wiring mode ---
    # ``fake``: factories return FakeGenAIClient / FakeOCRClient and
    # ``/healthz`` probes report ok. CI sets this so the Forgejo runner
    # doesn't need access to Ollama or GPU-backed Surya. ``None`` (default)
    # means production wiring: real OllamaClient + SuryaOCRClient.
    test_mode: Literal["fake"] | None = None


@lru_cache(maxsize=1)
def get_config() -> AppConfig:
    """Return the process-wide :class:`AppConfig` (materialise on first call).

    Wrapped in ``lru_cache`` so config is parsed + validated once per process.
    Tests call ``get_config.cache_clear()`` between scenarios; nothing in
    production should touch the cache.
    """

    return AppConfig()
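The cache-then-clear contract can be shown without pydantic-settings. A stdlib-only sketch of the ``get_config`` / ``cache_clear`` lifecycle; ``MiniConfig`` and ``IX_LOG_LEVEL`` handling here are illustrative stand-ins, not the real ``AppConfig``:

```python
import os
from dataclasses import dataclass
from functools import lru_cache

os.environ.pop("IX_LOG_LEVEL", None)  # start the demo from a clean env


@dataclass(frozen=True)
class MiniConfig:
    # Illustrative stand-in for AppConfig: one IX_-prefixed tunable.
    log_level: str


@lru_cache(maxsize=1)
def get_mini_config() -> MiniConfig:
    # Parsed once per process; later env changes stay invisible
    # until something calls cache_clear().
    return MiniConfig(log_level=os.environ.get("IX_LOG_LEVEL", "INFO"))


assert get_mini_config().log_level == "INFO"
os.environ["IX_LOG_LEVEL"] = "DEBUG"
assert get_mini_config().log_level == "INFO"  # cached instance wins
get_mini_config.cache_clear()                 # what tests do after monkeypatch.setenv
assert get_mini_config().log_level == "DEBUG"
```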
@@ -22,6 +22,11 @@ class FileRef(BaseModel):
    Used when the file URL needs authentication (e.g. Paperless ``Token``) or a
    tighter size cap than :envvar:`IX_FILE_MAX_BYTES`. Plain URLs that need no
    headers can stay as bare ``str`` values in :attr:`Context.files`.

    ``display_name`` is pure UI metadata — the pipeline never consults it for
    execution. When the UI uploads a PDF under a random ``{uuid}.pdf`` name on
    disk, it stashes the client-provided filename here so the browser can
    surface "your_statement.pdf" instead of "8f3a...pdf" back to the user.
    """

    model_config = ConfigDict(extra="forbid")
@@ -29,6 +34,7 @@ class FileRef(BaseModel):
    url: str
    headers: dict[str, str] = Field(default_factory=dict)
    max_bytes: int | None = None
    display_name: str | None = None


class Context(BaseModel):
@@ -83,6 +89,44 @@ class Options(BaseModel):
    provenance: ProvenanceOptions = Field(default_factory=ProvenanceOptions)


class UseCaseFieldDef(BaseModel):
    """One field in an ad-hoc, caller-defined extraction schema.

    The UI (and any other caller that doesn't want to wait on a backend
    registry entry) ships one of these per desired output field. The pipeline
    builds a fresh Pydantic response class from the list on each request.

    ``choices`` only applies to ``type == "str"`` — it turns the field into a
    ``Literal[*choices]``. For any other type the builder raises
    ``IX_001_001``.
    """

    model_config = ConfigDict(extra="forbid")

    name: str  # must be a valid Python identifier
    type: Literal["str", "int", "float", "decimal", "date", "datetime", "bool"]
    required: bool = False
    description: str | None = None
    choices: list[str] | None = None


class InlineUseCase(BaseModel):
    """Caller-defined use case bundled into the :class:`RequestIX`.

    When present on a request, the pipeline builds the ``(Request, Response)``
    Pydantic class pair on the fly from :attr:`fields` and skips the
    registered use-case lookup. The registry-based ``use_case`` field is still
    required on the request for metrics/logging but becomes a free-form label.
    """

    model_config = ConfigDict(extra="forbid")

    use_case_name: str
    system_prompt: str
    default_model: str | None = None
    fields: list[UseCaseFieldDef]


class RequestIX(BaseModel):
    """Top-level job request.

@@ -90,6 +134,12 @@ class RequestIX(BaseModel):
    it; the REST adapter / pg-queue adapter populates it on insert. The field
    is kept here so the contract round-trips closed over construction
    (e.g. when the worker re-hydrates a job out of the store).

    When ``use_case_inline`` is present, the pipeline uses it verbatim to
    build an ad-hoc ``(Request, Response)`` class pair and skips the registry
    lookup; ``use_case`` becomes a free-form label (still required for
    metrics/logging). When ``use_case_inline`` is absent, ``use_case`` is
    looked up in :data:`ix.use_cases.REGISTRY` as before.
    """

    model_config = ConfigDict(extra="forbid")
@@ -101,3 +151,4 @@ class RequestIX(BaseModel):
    context: Context
    options: Options = Field(default_factory=Options)
    callback_url: str | None = None
    use_case_inline: InlineUseCase | None = None
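The "``choices`` only for ``type == 'str'``" rule in ``UseCaseFieldDef`` can be sketched without pydantic. ``build_checks`` and ``SchemaError`` below are hypothetical helpers for illustration (``SchemaError`` stands in for the ``IX_001_001`` failure, and the type map covers only a subset of the spec's types):

```python
TYPE_MAP = {"str": str, "int": int, "float": float, "bool": bool}
# Subset only; the real builder also handles decimal/date/datetime.


class SchemaError(ValueError):
    # Hypothetical stand-in for the IX_001_001 schema-build failure.
    pass


def build_checks(fields: list[dict]) -> dict:
    """Return {name: check_fn} for UseCaseFieldDef-like dicts."""
    checks = {}
    for f in fields:
        if f.get("choices") is not None and f["type"] != "str":
            raise SchemaError("choices only valid for type == 'str'")
        py_type = TYPE_MAP[f["type"]]
        choices = f.get("choices")

        # Default args freeze this iteration's py_type/choices per field.
        def check(value, py_type=py_type, choices=choices):
            if not isinstance(value, py_type):
                return False
            return choices is None or value in choices

        checks[f["name"]] = check
    return checks


checks = build_checks([
    {"name": "currency", "type": "str", "choices": ["EUR", "USD"]},
    {"name": "amount", "type": "float"},
])
print(checks["currency"]("EUR"), checks["currency"]("GBP"))  # → True False
```

Constraining a string field to its ``choices`` is the dict-level analogue of the ``Literal[*choices]`` the real builder emits.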
src/ix/genai/__init__.py (new file, 43 lines)
@@ -0,0 +1,43 @@
"""GenAI subsystem: protocol + fake client + invocation-result dataclasses.

Real backends (Ollama, …) plug in behind :class:`GenAIClient`. The factory
:func:`make_genai_client` picks between :class:`FakeGenAIClient` (for CI
/ hermetic tests via ``IX_TEST_MODE=fake``) and :class:`OllamaClient`
(production). Tests that want a real Ollama client anyway can call the
constructor directly.
"""

from __future__ import annotations

from ix.config import AppConfig
from ix.genai.client import GenAIClient, GenAIInvocationResult, GenAIUsage
from ix.genai.fake import FakeGenAIClient
from ix.genai.ollama_client import OllamaClient


def make_genai_client(cfg: AppConfig) -> GenAIClient:
    """Return the :class:`GenAIClient` configured for the current run.

    When ``cfg.test_mode == "fake"`` the fake is returned; the pipeline
    callers are expected to override the injected client via DI if they
    want a non-default canned response. Otherwise a live
    :class:`OllamaClient` bound to ``cfg.ollama_url`` and the per-call
    timeout is returned.
    """

    if cfg.test_mode == "fake":
        return FakeGenAIClient(parsed=None)
    return OllamaClient(
        base_url=cfg.ollama_url,
        per_call_timeout_s=float(cfg.genai_call_timeout_seconds),
    )


__all__ = [
    "FakeGenAIClient",
    "GenAIClient",
    "GenAIInvocationResult",
    "GenAIUsage",
    "OllamaClient",
    "make_genai_client",
]
src/ix/genai/client.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""GenAIClient Protocol + invocation-result dataclasses (spec §6.3).

Structural typing: any object with an async
``invoke(request_kwargs, response_schema) -> GenAIInvocationResult``
method satisfies the Protocol. :class:`~ix.pipeline.genai_step.GenAIStep`
depends on the Protocol; swapping ``FakeGenAIClient`` in tests for
``OllamaClient`` in prod stays a wiring change.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel


@dataclass(slots=True)
class GenAIUsage:
    """Token counters returned by the LLM backend.

    Both fields default to 0 so fakes / degraded backends can omit them.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0


@dataclass(slots=True)
class GenAIInvocationResult:
    """One LLM call's full output.

    Attributes
    ----------
    parsed:
        The Pydantic instance produced from the model's structured output
        (typed as ``Any`` because the concrete class is the response schema
        passed into :meth:`GenAIClient.invoke`).
    usage:
        Token usage counters. Fakes may return a zero-filled
        :class:`GenAIUsage`.
    model_name:
        Echo of the model that served the request. Written to
        ``ix_result.meta_data['model_name']`` by
        :class:`~ix.pipeline.genai_step.GenAIStep`.
    """

    parsed: Any
    usage: GenAIUsage
    model_name: str


@runtime_checkable
class GenAIClient(Protocol):
    """Async LLM backend with structured-output support.

    Implementations accept an already-assembled ``request_kwargs`` dict
    (messages, model, format, etc.) and a Pydantic class describing the
    expected structured-output schema, and return a
    :class:`GenAIInvocationResult`.
    """

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Run the LLM; parse the response into ``response_schema``; return it."""
        ...


__all__ = ["GenAIClient", "GenAIInvocationResult", "GenAIUsage"]
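Structural satisfaction of the Protocol is checkable without importing the real clients. In this standalone sketch, ``GenAIClientLike`` re-declares the ``GenAIClient`` shape locally (dropping the pydantic-typed result for brevity) and ``EchoClient`` is a toy implementation:

```python
import asyncio
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class GenAIClientLike(Protocol):
    # Local re-declaration of the GenAIClient shape for illustration.
    async def invoke(
        self, request_kwargs: dict[str, Any], response_schema: type
    ) -> Any: ...


class EchoClient:
    # No inheritance needed: having a matching `invoke` method is enough.
    async def invoke(
        self, request_kwargs: dict[str, Any], response_schema: type
    ) -> Any:
        return request_kwargs.get("messages")


client = EchoClient()
# runtime_checkable only verifies method presence, not signatures.
print(isinstance(client, GenAIClientLike))                    # → True
print(asyncio.run(client.invoke({"messages": ["hi"]}, dict)))  # → ['hi']
```

This is exactly why swapping ``FakeGenAIClient`` for ``OllamaClient`` is a wiring change: neither class subclasses the Protocol.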
src/ix/genai/fake.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""FakeGenAIClient — returns a canned :class:`GenAIInvocationResult`.

Used by every pipeline unit test to avoid booting Ollama. The
``raise_on_call`` hook lets error-path tests exercise ``IX_002_000``-style
code paths without needing a real network error.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel

from ix.genai.client import GenAIInvocationResult, GenAIUsage


class FakeGenAIClient:
    """Satisfies :class:`~ix.genai.client.GenAIClient` structurally.

    Parameters
    ----------
    parsed:
        The pre-built model instance returned as ``result.parsed``.
    usage:
        Token usage counters. Defaults to a zero-filled
        :class:`GenAIUsage`.
    model_name:
        Echoed on the result. Defaults to ``"fake"``.
    raise_on_call:
        If set, :meth:`invoke` raises this exception instead of returning.
    """

    def __init__(
        self,
        parsed: Any,
        *,
        usage: GenAIUsage | None = None,
        model_name: str = "fake",
        raise_on_call: BaseException | None = None,
    ) -> None:
        self._parsed = parsed
        self._usage = usage if usage is not None else GenAIUsage()
        self._model_name = model_name
        self._raise_on_call = raise_on_call

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Return the canned result or raise the configured error."""
        if self._raise_on_call is not None:
            raise self._raise_on_call
        return GenAIInvocationResult(
            parsed=self._parsed,
            usage=self._usage,
            model_name=self._model_name,
        )


__all__ = ["FakeGenAIClient"]
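The ``raise_on_call`` error-injection pattern generalizes beyond this codebase. ``MiniFake`` below is a trimmed, pydantic-free stand-in (not the real ``FakeGenAIClient``) showing how one canned fake drives both the happy path and the error path:

```python
import asyncio


class MiniFake:
    # Trimmed FakeGenAIClient stand-in: canned result or injected exception.
    def __init__(self, parsed, raise_on_call=None):
        self._parsed = parsed
        self._raise_on_call = raise_on_call

    async def invoke(self, request_kwargs, response_schema):
        if self._raise_on_call is not None:
            raise self._raise_on_call
        return self._parsed


ok = MiniFake(parsed={"total": 42})
boom = MiniFake(parsed=None, raise_on_call=ConnectionError("backend down"))

print(asyncio.run(ok.invoke({}, dict)))  # → {'total': 42}
try:
    asyncio.run(boom.invoke({}, dict))
except ConnectionError as exc:
    print(f"caught: {exc}")  # error path exercised without any network
```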
src/ix/genai/ollama_client.py (new file, 340 lines)
@@ -0,0 +1,340 @@
"""OllamaClient — real :class:`GenAIClient` implementation (spec §6 GenAIStep).

Wraps the Ollama ``/api/chat`` structured-output endpoint. Per spec:

* POST ``{base_url}/api/chat`` with ``format = <pydantic JSON schema>``,
  ``stream = false``, and ``options`` carrying provider-neutral knobs
  (``temperature`` mapped, ``reasoning_effort`` dropped — Ollama ignores it).
* Messages are passed through. Content-parts lists (``[{"type":"text",...}]``)
  are joined to a single string because MVP models (``gpt-oss:20b`` /
  ``qwen2.5:32b``) don't accept native content-parts.
* Per-call timeout is enforced via ``httpx``. A connection refusal, read
  timeout, or 5xx maps to ``IX_002_000``. A 2xx whose ``message.content`` is
  not valid JSON for the schema maps to ``IX_002_001``.

``selfcheck()`` targets ``/api/tags`` with a fixed 5 s timeout and is what
``/healthz`` consumes.
"""

from __future__ import annotations

from typing import Any, Literal

import httpx
from pydantic import BaseModel, ValidationError

from ix.errors import IXErrorCode, IXException
from ix.genai.client import GenAIInvocationResult, GenAIUsage

_OLLAMA_TAGS_TIMEOUT_S: float = 5.0
_BODY_SNIPPET_MAX_CHARS: int = 240


class OllamaClient:
    """Async Ollama backend satisfying :class:`~ix.genai.client.GenAIClient`.

    Parameters
    ----------
    base_url:
        Root URL of the Ollama server (e.g. ``http://127.0.0.1:11434``).
        Trailing slashes are stripped.
    per_call_timeout_s:
        Hard per-call timeout for ``/api/chat``. Spec default: 1500 s.
    """

    def __init__(self, base_url: str, per_call_timeout_s: float) -> None:
        self._base_url = base_url.rstrip("/")
        self._per_call_timeout_s = per_call_timeout_s

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Run one structured-output chat call; parse into ``response_schema``."""

        body = self._translate_request(request_kwargs, response_schema)
        url = f"{self._base_url}/api/chat"
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._per_call_timeout_s) as http:
|
||||
resp = await http.post(url, json=body)
|
||||
except httpx.HTTPError as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=f"ollama {exc.__class__.__name__}: {exc}",
|
||||
) from exc
|
||||
except (ConnectionError, TimeoutError) as exc: # pragma: no cover - httpx wraps these
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=f"ollama {exc.__class__.__name__}: {exc}",
|
||||
) from exc
|
||||
|
||||
if resp.status_code >= 500:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=(
|
||||
f"ollama HTTP {resp.status_code}: "
|
||||
f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
|
||||
),
|
||||
)
|
||||
if resp.status_code >= 400:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=(
|
||||
f"ollama HTTP {resp.status_code}: "
|
||||
f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
payload = resp.json()
|
||||
except ValueError as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=f"ollama non-JSON body: {resp.text[:_BODY_SNIPPET_MAX_CHARS]}",
|
||||
) from exc
|
||||
|
||||
content = (payload.get("message") or {}).get("content") or ""
|
||||
json_blob = _extract_json_blob(content)
|
||||
try:
|
||||
parsed = response_schema.model_validate_json(json_blob)
|
||||
except ValidationError as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_001,
|
||||
detail=(
|
||||
f"{response_schema.__name__}: {exc.__class__.__name__}: "
|
||||
f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
|
||||
),
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
# ``model_validate_json`` raises ValueError on invalid JSON (not
|
||||
# a ValidationError). Treat as structured-output failure.
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_001,
|
||||
detail=(
|
||||
f"{response_schema.__name__}: invalid JSON: "
|
||||
f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
|
||||
),
|
||||
) from exc
|
||||
|
||||
usage = GenAIUsage(
|
||||
prompt_tokens=int(payload.get("prompt_eval_count") or 0),
|
||||
completion_tokens=int(payload.get("eval_count") or 0),
|
||||
)
|
||||
model_name = str(payload.get("model") or request_kwargs.get("model") or "")
|
||||
return GenAIInvocationResult(parsed=parsed, usage=usage, model_name=model_name)
|
||||
|
||||
async def selfcheck(
|
||||
self, expected_model: str
|
||||
) -> Literal["ok", "degraded", "fail"]:
|
||||
"""Probe ``/api/tags`` for ``/healthz``.
|
||||
|
||||
``ok`` when the server answers 2xx and ``expected_model`` is listed;
|
||||
``degraded`` when reachable but the model is missing; ``fail``
|
||||
otherwise. Spec §5, §11.
|
||||
"""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_OLLAMA_TAGS_TIMEOUT_S) as http:
|
||||
resp = await http.get(f"{self._base_url}/api/tags")
|
||||
except (httpx.HTTPError, ConnectionError, TimeoutError):
|
||||
return "fail"
|
||||
|
||||
if resp.status_code != 200:
|
||||
return "fail"
|
||||
|
||||
try:
|
||||
payload = resp.json()
|
||||
except ValueError:
|
||||
return "fail"
|
||||
|
||||
models = payload.get("models") or []
|
||||
names = {str(entry.get("name", "")) for entry in models}
|
||||
if expected_model in names:
|
||||
return "ok"
|
||||
return "degraded"
|
||||
|
||||
def _translate_request(
|
||||
self,
|
||||
request_kwargs: dict[str, Any],
|
||||
response_schema: type[BaseModel],
|
||||
) -> dict[str, Any]:
|
||||
"""Map provider-neutral kwargs to Ollama's /api/chat body.
|
||||
|
||||
Schema strategy for Ollama 0.11.8: we pass ``format="json"`` (loose
|
||||
JSON mode) and bake the Pydantic schema into a system message
|
||||
ahead of the caller's own system prompt. Rationale:
|
||||
|
||||
* The full Pydantic schema as ``format=<schema>`` crashes llama.cpp's
|
||||
structured-output implementation (SIGSEGV) on every non-trivial
|
||||
shape — ``anyOf`` / ``$ref`` / ``pattern`` all trigger it.
|
||||
* ``format="json"`` alone guarantees valid JSON but not the shape;
|
||||
models routinely return ``{}`` when not told what fields to emit.
|
||||
* Injecting the schema into the prompt is the cheapest way to
|
||||
get both: the model sees the expected shape explicitly, Pydantic
|
||||
validates the response at parse time (IX_002_001 on mismatch).
|
||||
|
||||
Non-Ollama ``GenAIClient`` impls can ignore this behaviour and use
|
||||
native structured-output (``response_format`` on OpenAI, etc.).
|
||||
"""
|
||||
|
||||
messages = self._translate_messages(
|
||||
list(request_kwargs.get("messages") or [])
|
||||
)
|
||||
messages = _inject_schema_system_message(messages, response_schema)
|
||||
body: dict[str, Any] = {
|
||||
"model": request_kwargs.get("model"),
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
# NOTE: format is deliberately omitted. `format="json"` made
|
||||
# reasoning models (qwen3) abort after emitting `{}` because the
|
||||
# constrained sampler terminated before the chain-of-thought
|
||||
# finished; `format=<schema>` segfaulted Ollama 0.11.8. Letting
|
||||
# the model stream freely and then extracting the trailing JSON
|
||||
# blob works for both reasoning and non-reasoning models.
|
||||
}
|
||||
|
||||
options: dict[str, Any] = {}
|
||||
if "temperature" in request_kwargs:
|
||||
options["temperature"] = request_kwargs["temperature"]
|
||||
# reasoning_effort intentionally dropped — Ollama doesn't support it.
|
||||
if options:
|
||||
body["options"] = options
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
def _translate_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Collapse content-parts lists into single strings for Ollama."""
|
||||
out: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
str(part.get("text", ""))
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
new_content = "\n".join(text_parts)
|
||||
else:
|
||||
new_content = content
|
||||
out.append({**msg, "content": new_content})
|
||||
return out
|
||||
|
||||
|
||||
def _extract_json_blob(text: str) -> str:
|
||||
"""Return the outermost balanced JSON object in ``text``.
|
||||
|
||||
Reasoning models (qwen3, deepseek-r1) wrap their real answer in
|
||||
``<think>…</think>`` blocks. Other models sometimes prefix prose or
|
||||
fence the JSON in ```json``` code blocks. Finding the last balanced
|
||||
``{…}`` is the cheapest robust parse that works for all three shapes;
|
||||
a malformed response yields the full text and Pydantic catches it
|
||||
downstream as ``IX_002_001``.
|
||||
"""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return text
|
||||
depth = 0
|
||||
in_string = False
|
||||
escaped = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif ch == "\\":
|
||||
escaped = True
|
||||
elif ch == '"':
|
||||
in_string = False
|
||||
continue
|
||||
if ch == '"':
|
||||
in_string = True
|
||||
elif ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return text[start:]
|
||||
|
||||
|
||||
def _inject_schema_system_message(
|
||||
messages: list[dict[str, Any]],
|
||||
response_schema: type[BaseModel],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Prepend a system message that pins the expected JSON shape.
|
||||
|
||||
Ollama's ``format="json"`` mode guarantees valid JSON but not the
|
||||
field set or names. We emit the Pydantic schema as JSON and
|
||||
instruct the model to match it. If the caller already provides a
|
||||
system message, we prepend ours; otherwise ours becomes the first
|
||||
system turn.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
schema_json = _json.dumps(
|
||||
_sanitise_schema_for_ollama(response_schema.model_json_schema()),
|
||||
indent=2,
|
||||
)
|
||||
guidance = (
|
||||
"Respond ONLY with a single JSON object matching this JSON Schema "
|
||||
"exactly. No prose, no code fences, no explanations. All top-level "
|
||||
"properties listed in `required` MUST be present. Use null for "
|
||||
"fields you cannot confidently extract. The JSON Schema:\n"
|
||||
f"{schema_json}"
|
||||
)
|
||||
return [{"role": "system", "content": guidance}, *messages]
|
||||
|
||||
|
||||
def _sanitise_schema_for_ollama(schema: Any) -> Any:
|
||||
"""Strip null branches from ``anyOf`` unions.
|
||||
|
||||
Ollama 0.11.8's llama.cpp structured-output implementation segfaults on
|
||||
Pydantic v2's standard Optional pattern::
|
||||
|
||||
{"anyOf": [{"type": "string"}, {"type": "null"}]}
|
||||
|
||||
We collapse any ``anyOf`` that includes a ``{"type": "null"}`` entry to
|
||||
its non-null branch — single branch becomes that branch inline; multiple
|
||||
branches keep the union without null. This only narrows what the LLM is
|
||||
*told* it may emit; Pydantic still validates the real response and can
|
||||
accept ``None`` at parse time if the field is ``Optional``.
|
||||
|
||||
Walk is recursive and structure-preserving. Other ``anyOf`` shapes (e.g.
|
||||
polymorphic unions without null) are left alone.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
cleaned: dict[str, Any] = {}
|
||||
for key, value in schema.items():
|
||||
if key == "anyOf" and isinstance(value, list):
|
||||
non_null = [
|
||||
_sanitise_schema_for_ollama(branch)
|
||||
for branch in value
|
||||
if not (isinstance(branch, dict) and branch.get("type") == "null")
|
||||
]
|
||||
if len(non_null) == 1:
|
||||
# Inline the single remaining branch; merge its keys into the
|
||||
# parent so siblings like ``default``/``title`` are preserved.
|
||||
only = non_null[0]
|
||||
if isinstance(only, dict):
|
||||
for ok, ov in only.items():
|
||||
cleaned.setdefault(ok, ov)
|
||||
else:
|
||||
cleaned[key] = non_null
|
||||
elif len(non_null) == 0:
|
||||
# Pathological: nothing left. Fall back to a permissive type.
|
||||
cleaned["type"] = "string"
|
||||
else:
|
||||
cleaned[key] = non_null
|
||||
else:
|
||||
cleaned[key] = _sanitise_schema_for_ollama(value)
|
||||
return cleaned
|
||||
if isinstance(schema, list):
|
||||
return [_sanitise_schema_for_ollama(item) for item in schema]
|
||||
return schema
|
||||
|
||||
|
||||
__all__ = ["OllamaClient"]
|
||||
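The balanced-brace scan in `_extract_json_blob` is self-contained enough to exercise outside the project. This is a standalone copy of the same logic (renamed for the demo), showing how it survives reasoning-model `<think>` wrappers, prose prefixes, and braces inside JSON strings:

```python
def extract_json_blob(text: str) -> str:
    """Return the first balanced {...} object in text, or text unchanged."""
    start = text.find("{")
    if start < 0:
        return text
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a string literal: only track escapes and the closing quote,
            # so braces inside string values don't perturb the depth counter.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return text[start:]


print(extract_json_blob('<think>hmm</think>{"a": {"b": 1}}'))  # {"a": {"b": 1}}
print(extract_json_blob('Sure! here it is: {"x": "}"}'))       # {"x": "}"}
print(extract_json_blob("no json at all"))                     # no json at all
```

The last case is the deliberate fallback: a blob-less response passes through untouched and fails Pydantic validation downstream, which is exactly the `IX_002_001` path.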
27 src/ix/ingestion/__init__.py Normal file
@@ -0,0 +1,27 @@
"""Ingestion pipeline helpers: fetch → MIME-detect → build pages.
|
||||
|
||||
Three modules layered bottom-up:
|
||||
|
||||
* :mod:`ix.ingestion.fetch` — async HTTP(S) / ``file://`` downloader with
|
||||
incremental size caps and pluggable timeouts.
|
||||
* :mod:`ix.ingestion.mime` — byte-sniffing MIME detection + the
|
||||
MVP-supported MIME set.
|
||||
* :mod:`ix.ingestion.pages` — :class:`DocumentIngestor` that turns local
|
||||
files + raw texts into the flat :class:`~ix.contracts.Page` list the
|
||||
OCR step expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.ingestion.fetch import FetchConfig, fetch_file
|
||||
from ix.ingestion.mime import SUPPORTED_MIMES, detect_mime, require_supported
|
||||
from ix.ingestion.pages import DocumentIngestor
|
||||
|
||||
__all__ = [
|
||||
"SUPPORTED_MIMES",
|
||||
"DocumentIngestor",
|
||||
"FetchConfig",
|
||||
"detect_mime",
|
||||
"fetch_file",
|
||||
"require_supported",
|
||||
]
|
||||
144 src/ix/ingestion/fetch.py Normal file
@@ -0,0 +1,144 @@
"""Async file fetcher (spec §6.1).
|
||||
|
||||
Supports ``http(s)://`` URLs (via httpx with configurable connect/read
|
||||
timeouts and an incremental size cap) and ``file://`` URLs (read from
|
||||
local fs — used by the E2E fixture). Auth headers on the :class:`FileRef`
|
||||
pass through unchanged.
|
||||
|
||||
Every failure mode surfaces as :attr:`~ix.errors.IXErrorCode.IX_000_007`
|
||||
with the offending URL + cause in the ``detail`` slot so the caller log
|
||||
line is grep-friendly.
|
||||
|
||||
Env-driven defaults live in :mod:`ix.config` (Chunk 3). The caller injects
|
||||
a :class:`FetchConfig` — this module is purely mechanical.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from ix.contracts import FileRef
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FetchConfig:
|
||||
"""Per-fetch knobs injected by the caller.
|
||||
|
||||
``connect_timeout_s`` / ``read_timeout_s`` → httpx timeouts.
|
||||
``max_bytes`` is the pipeline-wide default cap; the per-file override
|
||||
on :attr:`~ix.contracts.FileRef.max_bytes` wins when lower.
|
||||
"""
|
||||
|
||||
connect_timeout_s: float
|
||||
read_timeout_s: float
|
||||
max_bytes: int
|
||||
|
||||
|
||||
def _effective_cap(file_ref: FileRef, cfg: FetchConfig) -> int:
|
||||
"""The smaller of the pipeline-wide cap and the per-file override."""
|
||||
if file_ref.max_bytes is None:
|
||||
return cfg.max_bytes
|
||||
return min(cfg.max_bytes, file_ref.max_bytes)
|
||||
|
||||
|
||||
def _safe_filename(url: str) -> str:
|
||||
"""Derive a readable filename for the scratch copy from the URL."""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
candidate = Path(parsed.path).name or "download"
|
||||
# Strip anything that would escape the tmp dir.
|
||||
return candidate.replace("/", "_").replace("\\", "_")
|
||||
|
||||
|
||||
async def _fetch_http(file_ref: FileRef, dst: Path, cfg: FetchConfig) -> None:
|
||||
"""HTTP(S) download with incremental size cap. Raises IX_000_007 on any failure."""
|
||||
cap = _effective_cap(file_ref, cfg)
|
||||
timeout = httpx.Timeout(
|
||||
cfg.read_timeout_s,
|
||||
connect=cfg.connect_timeout_s,
|
||||
)
|
||||
try:
|
||||
async with (
|
||||
httpx.AsyncClient(timeout=timeout) as client,
|
||||
client.stream(
|
||||
"GET",
|
||||
file_ref.url,
|
||||
headers=file_ref.headers or None,
|
||||
) as response,
|
||||
):
|
||||
if response.status_code >= 300:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: HTTP {response.status_code}",
|
||||
)
|
||||
total = 0
|
||||
with dst.open("wb") as fh:
|
||||
async for chunk in response.aiter_bytes():
|
||||
total += len(chunk)
|
||||
if total > cap:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: size cap {cap} bytes exceeded",
|
||||
)
|
||||
fh.write(chunk)
|
||||
except IXException:
|
||||
raise
|
||||
except httpx.TimeoutException as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: timeout ({exc.__class__.__name__})",
|
||||
) from exc
|
||||
except httpx.HTTPError as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: {exc.__class__.__name__}: {exc}",
|
||||
) from exc
|
||||
|
||||
|
||||
def _fetch_file_scheme(file_ref: FileRef, dst: Path, cfg: FetchConfig) -> None:
|
||||
"""Local-path read via ``file://`` URL. Same failure-mode contract."""
|
||||
cap = _effective_cap(file_ref, cfg)
|
||||
src_path = Path(urllib.parse.urlparse(file_ref.url).path)
|
||||
if not src_path.exists():
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: file does not exist",
|
||||
)
|
||||
size = src_path.stat().st_size
|
||||
if size > cap:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: size cap {cap} bytes exceeded",
|
||||
)
|
||||
dst.write_bytes(src_path.read_bytes())
|
||||
|
||||
|
||||
async def fetch_file(file_ref: FileRef, tmp_dir: Path, cfg: FetchConfig) -> Path:
|
||||
"""Download / copy ``file_ref`` into ``tmp_dir`` and return the local path.
|
||||
|
||||
http(s) and file:// URLs both supported. Any fetch failure raises
|
||||
:class:`~ix.errors.IXException` with
|
||||
:attr:`~ix.errors.IXErrorCode.IX_000_007`.
|
||||
"""
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
scheme = urllib.parse.urlparse(file_ref.url).scheme.lower()
|
||||
dst = tmp_dir / _safe_filename(file_ref.url)
|
||||
|
||||
if scheme in ("http", "https"):
|
||||
await _fetch_http(file_ref, dst, cfg)
|
||||
elif scheme == "file":
|
||||
_fetch_file_scheme(file_ref, dst, cfg)
|
||||
else:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_007,
|
||||
detail=f"{file_ref.url}: unsupported URL scheme {scheme!r}",
|
||||
)
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
__all__ = ["FetchConfig", "fetch_file"]
|
||||
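The incremental cap in `_fetch_http` counts bytes as chunks arrive and aborts the moment the running total crosses the cap, rather than trusting a Content-Length header. Stripped of the httpx and error-code machinery, the core loop reduces to this sketch (plain iterables and `ValueError` stand in for the real stream and `IXException`):

```python
def copy_with_cap(chunks, cap: int) -> bytes:
    """Accumulate chunks, raising as soon as the running total exceeds cap."""
    total = 0
    out = bytearray()
    for chunk in chunks:
        total += len(chunk)
        if total > cap:
            # Abort mid-stream: never buffer past the cap, even if the
            # server lied about (or omitted) the content length.
            raise ValueError(f"size cap {cap} bytes exceeded")
        out.extend(chunk)
    return bytes(out)


print(copy_with_cap([b"ab", b"cd"], cap=4))  # b'abcd'
try:
    copy_with_cap([b"ab", b"cd", b"e"], cap=4)
except ValueError as exc:
    print(exc)  # size cap 4 bytes exceeded
```

Note the check runs before the write, so at most one over-cap chunk is ever held in memory and nothing past the cap reaches disk.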
37 src/ix/ingestion/mime.py Normal file
@@ -0,0 +1,37 @@
"""MIME detection + supported-MIME gate (spec §6.1).
|
||||
|
||||
Bytes-only; URL extensions are ignored because callers (Paperless, …)
|
||||
may serve `/download` routes without a file suffix. ``python-magic``
|
||||
reads the file header and returns the canonical MIME.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import magic
|
||||
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
|
||||
SUPPORTED_MIMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/tiff",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def detect_mime(path: Path) -> str:
|
||||
"""Return the canonical MIME string for ``path`` (byte-sniffed)."""
|
||||
return magic.from_file(str(path), mime=True)
|
||||
|
||||
|
||||
def require_supported(mime: str) -> None:
|
||||
"""Raise :class:`~ix.errors.IXException` (``IX_000_005``) if ``mime`` is unsupported."""
|
||||
if mime not in SUPPORTED_MIMES:
|
||||
raise IXException(IXErrorCode.IX_000_005, detail=mime)
|
||||
|
||||
|
||||
__all__ = ["SUPPORTED_MIMES", "detect_mime", "require_supported"]
|
||||
118 src/ix/ingestion/pages.py Normal file
@@ -0,0 +1,118 @@
"""Turn downloaded files + raw texts into a flat :class:`Page` list (spec §6.1).
|
||||
|
||||
PDFs → one :class:`Page` per page via PyMuPDF, with a hard 100 pages/PDF
|
||||
cap (``IX_000_006``). Images → Pillow; multi-frame TIFFs yield one Page
|
||||
per frame. Texts → one zero-dimension Page each so the downstream OCR /
|
||||
GenAI steps can still cite them.
|
||||
|
||||
A parallel list of :class:`~ix.segmentation.PageMetadata` is returned so
|
||||
the pipeline (via :class:`SegmentIndex`) can resolve segment IDs back to
|
||||
``file_index`` anchors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import fitz # PyMuPDF
|
||||
from PIL import Image, ImageSequence
|
||||
|
||||
from ix.contracts import Page
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.segmentation import PageMetadata
|
||||
|
||||
_PDF_PAGE_CAP = 100
|
||||
|
||||
|
||||
class DocumentIngestor:
|
||||
"""Builds the flat Page list that feeds :class:`~ix.pipeline.ocr_step.OCRStep`.
|
||||
|
||||
No constructor args for MVP — the 100-page cap is a spec constant. If
|
||||
this needs to be tunable later, move it to a dataclass config.
|
||||
"""
|
||||
|
||||
def build_pages(
|
||||
self,
|
||||
files: list[tuple[Path, str]],
|
||||
texts: list[str],
|
||||
) -> tuple[list[Page], list[PageMetadata]]:
|
||||
"""Return ``(pages, metas)`` in insertion order.
|
||||
|
||||
``files`` is a list of ``(local_path, mime_type)`` tuples; mimes
|
||||
have already been validated by :func:`ix.ingestion.mime.require_supported`.
|
||||
"""
|
||||
pages: list[Page] = []
|
||||
metas: list[PageMetadata] = []
|
||||
|
||||
for file_index, (path, mime) in enumerate(files):
|
||||
if mime == "application/pdf":
|
||||
self._extend_with_pdf(path, file_index, pages, metas)
|
||||
elif mime in ("image/png", "image/jpeg", "image/tiff"):
|
||||
self._extend_with_image(path, file_index, pages, metas)
|
||||
else: # pragma: no cover - defensive; require_supported should gate upstream
|
||||
raise IXException(IXErrorCode.IX_000_005, detail=mime)
|
||||
|
||||
for _ in texts:
|
||||
# Text-backed pages are zero-dim; they exist so the GenAIStep
|
||||
# can merge their content into the prompt alongside OCR.
|
||||
pages.append(
|
||||
Page(
|
||||
page_no=len(pages) + 1,
|
||||
width=0.0,
|
||||
height=0.0,
|
||||
lines=[],
|
||||
)
|
||||
)
|
||||
metas.append(PageMetadata(file_index=None))
|
||||
|
||||
return pages, metas
|
||||
|
||||
def _extend_with_pdf(
|
||||
self,
|
||||
path: Path,
|
||||
file_index: int,
|
||||
pages: list[Page],
|
||||
metas: list[PageMetadata],
|
||||
) -> None:
|
||||
doc = fitz.open(str(path))
|
||||
try:
|
||||
if doc.page_count > _PDF_PAGE_CAP:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_000_006,
|
||||
detail=f"{path}: {doc.page_count} pages (cap {_PDF_PAGE_CAP})",
|
||||
)
|
||||
for page in doc:
|
||||
rect = page.rect
|
||||
pages.append(
|
||||
Page(
|
||||
page_no=len(pages) + 1,
|
||||
width=float(rect.width),
|
||||
height=float(rect.height),
|
||||
lines=[],
|
||||
)
|
||||
)
|
||||
metas.append(PageMetadata(file_index=file_index))
|
||||
finally:
|
||||
doc.close()
|
||||
|
||||
def _extend_with_image(
|
||||
self,
|
||||
path: Path,
|
||||
file_index: int,
|
||||
pages: list[Page],
|
||||
metas: list[PageMetadata],
|
||||
) -> None:
|
||||
with Image.open(path) as img:
|
||||
for frame in ImageSequence.Iterator(img):
|
||||
pages.append(
|
||||
Page(
|
||||
page_no=len(pages) + 1,
|
||||
width=float(frame.width),
|
||||
height=float(frame.height),
|
||||
lines=[],
|
||||
)
|
||||
)
|
||||
metas.append(PageMetadata(file_index=file_index))
|
||||
|
||||
|
||||
__all__ = ["DocumentIngestor"]
|
||||
34 src/ix/ocr/__init__.py Normal file
@@ -0,0 +1,34 @@
"""OCR subsystem: protocol + fake + real Surya client + factory.
|
||||
|
||||
Real engines (Surya today, Azure DI / AWS Textract … deferred) plug in
|
||||
behind :class:`OCRClient`. The factory :func:`make_ocr_client` picks
|
||||
between :class:`FakeOCRClient` (when ``IX_TEST_MODE=fake``) and
|
||||
:class:`SuryaOCRClient` (production). Unknown engine names raise so a
|
||||
typo'd ``IX_OCR_ENGINE`` surfaces at startup, not later.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.config import AppConfig
|
||||
from ix.contracts.response import OCRDetails, OCRResult
|
||||
from ix.ocr.client import OCRClient
|
||||
from ix.ocr.fake import FakeOCRClient
|
||||
from ix.ocr.surya_client import SuryaOCRClient
|
||||
|
||||
|
||||
def make_ocr_client(cfg: AppConfig) -> OCRClient:
|
||||
"""Return the :class:`OCRClient` configured for the current run."""
|
||||
|
||||
if cfg.test_mode == "fake":
|
||||
return FakeOCRClient(canned=OCRResult(result=OCRDetails()))
|
||||
if cfg.ocr_engine == "surya":
|
||||
return SuryaOCRClient()
|
||||
raise ValueError(f"Unknown ocr_engine: {cfg.ocr_engine!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FakeOCRClient",
|
||||
"OCRClient",
|
||||
"SuryaOCRClient",
|
||||
"make_ocr_client",
|
||||
]
|
||||
50 src/ix/ocr/client.py Normal file
@@ -0,0 +1,50 @@
"""OCRClient Protocol (spec §6.2).
|
||||
|
||||
Structural typing: any object with an async ``ocr(pages) -> OCRResult``
|
||||
method satisfies the Protocol. :class:`~ix.pipeline.ocr_step.OCRStep`
|
||||
depends on the Protocol, not a concrete class, so swapping engines
|
||||
(``FakeOCRClient`` in tests, ``SuryaOCRClient`` in prod) stays a wiring
|
||||
change at the app factory.
|
||||
|
||||
Per-page source location (``files`` + ``page_metadata``) flows in as
|
||||
optional kwargs: fakes ignore them; the real
|
||||
:class:`~ix.ocr.surya_client.SuryaOCRClient` uses them to render each
|
||||
page's pixels back from disk. Keeping these optional lets unit tests stay
|
||||
pages-only while production wiring (Task 4.3) plumbs through the real
|
||||
filesystem handles.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from ix.contracts import OCRResult, Page
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OCRClient(Protocol):
|
||||
"""Async OCR backend.
|
||||
|
||||
Implementations receive the flat page list the pipeline built in
|
||||
:class:`~ix.pipeline.setup_step.SetupStep` and return an
|
||||
:class:`~ix.contracts.OCRResult` with one :class:`~ix.contracts.Page`
|
||||
per input page (in the same order).
|
||||
"""
|
||||
|
||||
async def ocr(
|
||||
self,
|
||||
pages: list[Page],
|
||||
*,
|
||||
files: list[tuple[Path, str]] | None = None,
|
||||
page_metadata: list[Any] | None = None,
|
||||
) -> OCRResult:
|
||||
"""Run OCR over the input pages; return the structured result.
|
||||
|
||||
``files`` and ``page_metadata`` are optional for hermetic tests;
|
||||
real engines that need to re-render from disk read them.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["OCRClient"]
|
||||
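The structural-typing claim above — any class with the right method satisfies the Protocol, no inheritance needed — is easy to demonstrate with a toy `runtime_checkable` Protocol (names here are invented for the demo; note that `isinstance` against a runtime-checkable Protocol only checks member *presence*, not signatures or async-ness):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Greeter(Protocol):
    def greet(self, name: str) -> str: ...


class PlainGreeter:
    """Never mentions Greeter — structural typing makes it one anyway."""

    def greet(self, name: str) -> str:
        return f"hi {name}"


print(isinstance(PlainGreeter(), Greeter))  # True  — method is present
print(isinstance(object(), Greeter))        # False — no greet attribute
```

This is exactly why `FakeOCRClient` and `SuryaOCRClient` share no base class: the Protocol is the contract, and the factory can hand either to `OCRStep`.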
49 src/ix/ocr/fake.py Normal file
@@ -0,0 +1,49 @@
"""FakeOCRClient — returns a canned :class:`OCRResult` for hermetic tests.
|
||||
|
||||
Used by every pipeline unit test to avoid booting Surya / CUDA. The
|
||||
``raise_on_call`` hook lets error-path tests exercise ``IX_002_000``-style
|
||||
code paths without needing to forge network errors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import OCRResult, Page
|
||||
|
||||
|
||||
class FakeOCRClient:
|
||||
"""Satisfies :class:`~ix.ocr.client.OCRClient` structurally.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
canned:
|
||||
The :class:`OCRResult` to return from every :meth:`ocr` call.
|
||||
raise_on_call:
|
||||
If set, :meth:`ocr` raises this exception instead of returning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
canned: OCRResult,
|
||||
*,
|
||||
raise_on_call: BaseException | None = None,
|
||||
) -> None:
|
||||
self._canned = canned
|
||||
self._raise_on_call = raise_on_call
|
||||
|
||||
async def ocr(
|
||||
self,
|
||||
pages: list[Page],
|
||||
**_kwargs: object,
|
||||
) -> OCRResult:
|
||||
"""Return the canned result or raise the configured error.
|
||||
|
||||
Accepts (and ignores) any keyword args the production Protocol may
|
||||
carry — keeps the fake swappable for :class:`SuryaOCRClient` at
|
||||
call sites that pass ``files`` / ``page_metadata``.
|
||||
"""
|
||||
if self._raise_on_call is not None:
|
||||
raise self._raise_on_call
|
||||
return self._canned
|
||||
|
||||
|
||||
__all__ = ["FakeOCRClient"]
|
||||
252 src/ix/ocr/surya_client.py Normal file
@@ -0,0 +1,252 @@
"""SuryaOCRClient — real :class:`OCRClient` backed by ``surya-ocr``.
|
||||
|
||||
Per spec §6.2: the MVP OCR engine. Runs Surya's detection + recognition
|
||||
predictors over per-page PIL images rendered from the downloaded sources
|
||||
(PDFs via PyMuPDF, images via Pillow).
|
||||
|
||||
Design choices:
|
||||
|
||||
* **Lazy model loading.** ``__init__`` is cheap; the heavy predictors are
|
||||
built on first :meth:`ocr` / :meth:`selfcheck` / explicit :meth:`warm_up`.
|
||||
This keeps FastAPI's lifespan predictable — ops can decide whether to
|
||||
pay the load cost up front or on first request.
|
||||
* **Device is Surya's default.** CUDA on the prod box, MPS on M-series Macs.
|
||||
We deliberately don't pin.
|
||||
* **No text-token reuse from PyMuPDF.** The cross-check against Paperless'
|
||||
Tesseract output (ReliabilityStep's ``text_agreement``) is only meaningful
|
||||
with a truly independent OCR pass, so we always render-and-recognize
|
||||
even for PDFs that carry embedded text.
|
||||
|
||||
The ``surya-ocr`` package pulls torch + heavy model deps, so it's kept
behind the ``[ocr]`` extra. All Surya imports are deferred into
:meth:`warm_up` so running the unit tests (which patch the predictors)
doesn't require the package to be installed.
"""

from __future__ import annotations

import asyncio
import contextlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from ix.contracts import Line, OCRDetails, OCRResult, Page
from ix.segmentation import PageMetadata

if TYPE_CHECKING:  # pragma: no cover
    from PIL import Image as PILImage


class SuryaOCRClient:
    """Surya-backed OCR engine.

    Attributes are created lazily by :meth:`warm_up`. The unit tests inject
    mocks directly onto ``_recognition_predictor`` / ``_detection_predictor``
    to avoid the Surya import chain.
    """

    def __init__(self) -> None:
        self._recognition_predictor: Any = None
        self._detection_predictor: Any = None
        # ``None`` until warm_up() has run at least once. After that it's the
        # observed value of ``torch.cuda.is_available()`` at load time. We
        # cache it on the instance so ``/healthz`` / the UI can surface a
        # CPU-mode warning without re-probing torch each request.
        self.gpu_available: bool | None = None

    def warm_up(self) -> None:
        """Load the detection + recognition predictors. Idempotent.

        Called automatically on the first :meth:`ocr` / :meth:`selfcheck`,
        or explicitly from the app lifespan to front-load the cost.
        """
        if (
            self._recognition_predictor is not None
            and self._detection_predictor is not None
        ):
            return

        # Deferred imports: only reachable when the optional [ocr] extra is
        # installed. Keeping them inside the method so base-install unit
        # tests (which patch the predictors) don't need surya on sys.path.
        from surya.detection import DetectionPredictor  # type: ignore[import-not-found]
        from surya.foundation import FoundationPredictor  # type: ignore[import-not-found]
        from surya.recognition import RecognitionPredictor  # type: ignore[import-not-found]

        foundation = FoundationPredictor()
        self._recognition_predictor = RecognitionPredictor(foundation)
        self._detection_predictor = DetectionPredictor()

        # Best-effort CUDA probe — only after predictors loaded cleanly so we
        # know torch is fully importable. ``torch`` is a Surya transitive
        # dependency so if we got this far it's on sys.path. We swallow any
        # exception to keep warm_up() sturdy: the attribute stays None and the
        # UI falls back to "unknown" gracefully.
        try:
            import torch  # type: ignore[import-not-found]

            self.gpu_available = bool(torch.cuda.is_available())
        except Exception:
            self.gpu_available = None

    async def ocr(
        self,
        pages: list[Page],
        *,
        files: list[tuple[Path, str]] | None = None,
        page_metadata: list[Any] | None = None,
    ) -> OCRResult:
        """Render each input page, run Surya, translate back to contracts."""
        self.warm_up()

        images = self._render_pages(pages, files, page_metadata)

        # Surya is blocking — run it off the event loop.
        loop = asyncio.get_running_loop()
        surya_results = await loop.run_in_executor(
            None, self._run_recognition, images
        )

        out_pages: list[Page] = []
        all_text_fragments: list[str] = []
        for input_page, surya_result in zip(pages, surya_results, strict=True):
            lines: list[Line] = []
            for tl in getattr(surya_result, "text_lines", []) or []:
                flat = self._flatten_polygon(getattr(tl, "polygon", None))
                text = getattr(tl, "text", None)
                lines.append(Line(text=text, bounding_box=flat))
                if text:
                    all_text_fragments.append(text)
            out_pages.append(
                Page(
                    page_no=input_page.page_no,
                    width=input_page.width,
                    height=input_page.height,
                    angle=input_page.angle,
                    unit=input_page.unit,
                    lines=lines,
                )
            )

        details = OCRDetails(
            text="\n".join(all_text_fragments) if all_text_fragments else None,
            pages=out_pages,
        )
        return OCRResult(result=details, meta_data={"engine": "surya"})

    async def selfcheck(self) -> Literal["ok", "fail"]:
        """Run the predictors on a 1x1 image to confirm the stack works."""
        try:
            self.warm_up()
        except Exception:
            return "fail"

        try:
            from PIL import Image as PILImageRuntime

            img = PILImageRuntime.new("RGB", (1, 1), color="white")
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._run_recognition, [img])
        except Exception:
            return "fail"
        return "ok"

    def _run_recognition(self, images: list[PILImage.Image]) -> list[Any]:
        """Invoke the recognition predictor. Kept tiny for threadpool offload."""
        return list(
            self._recognition_predictor(
                images, det_predictor=self._detection_predictor
            )
        )

    def _render_pages(
        self,
        pages: list[Page],
        files: list[tuple[Path, str]] | None,
        page_metadata: list[Any] | None,
    ) -> list[PILImage.Image]:
        """Render each input :class:`Page` to a PIL image.

        We walk pages + page_metadata in lockstep so we know which source
        file each page came from and (for PDFs) what page-index to render.
        Text-only pages (``file_index is None``) get a blank 1x1 placeholder
        so Surya returns an empty result and downstream code still gets one
        entry per input page.
        """
        from PIL import Image as PILImageRuntime

        metas: list[PageMetadata] = list(page_metadata or [])
        file_records: list[tuple[Path, str]] = list(files or [])

        # Per-file lazy PDF openers so we don't re-open across pages.
        pdf_docs: dict[int, Any] = {}

        # Per-file running page-within-file counter. For PDFs we emit one
        # entry per PDF page in order; ``pages`` was built the same way by
        # DocumentIngestor, so a parallel counter reconstructs the mapping.
        per_file_cursor: dict[int, int] = {}

        rendered: list[PILImage.Image] = []
        try:
            for idx, _page in enumerate(pages):
                meta = metas[idx] if idx < len(metas) else PageMetadata()
                file_index = meta.file_index
                if file_index is None or file_index >= len(file_records):
                    # Text-only page — placeholder image; Surya returns empty.
                    rendered.append(
                        PILImageRuntime.new("RGB", (1, 1), color="white")
                    )
                    continue

                local_path, mime = file_records[file_index]
                if mime == "application/pdf":
                    doc = pdf_docs.get(file_index)
                    if doc is None:
                        import fitz  # PyMuPDF

                        doc = fitz.open(str(local_path))
                        pdf_docs[file_index] = doc
                    pdf_page_no = per_file_cursor.get(file_index, 0)
                    per_file_cursor[file_index] = pdf_page_no + 1
                    pdf_page = doc.load_page(pdf_page_no)
                    pix = pdf_page.get_pixmap(dpi=200)
                    img = PILImageRuntime.frombytes(
                        "RGB", (pix.width, pix.height), pix.samples
                    )
                    rendered.append(img)
                elif mime in ("image/png", "image/jpeg", "image/tiff"):
                    frame_no = per_file_cursor.get(file_index, 0)
                    per_file_cursor[file_index] = frame_no + 1
                    img = PILImageRuntime.open(local_path)
                    # Handle multi-frame (TIFF) — seek to the right frame.
                    with contextlib.suppress(EOFError):
                        img.seek(frame_no)
                    rendered.append(img.convert("RGB"))
                else:  # pragma: no cover - ingestor already rejected
                    rendered.append(
                        PILImageRuntime.new("RGB", (1, 1), color="white")
                    )
        finally:
            for doc in pdf_docs.values():
                with contextlib.suppress(Exception):
                    doc.close()
        return rendered

    @staticmethod
    def _flatten_polygon(polygon: Any) -> list[float]:
        """Flatten ``[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]`` → 8-float list.

        Surya emits 4 quad corners. The spec wants 8 raw-pixel coords so
        downstream provenance normalisation can consume them directly.
        """
        if not polygon:
            return []
        flat: list[float] = []
        for point in polygon:
            if isinstance(point, (list, tuple)) and len(point) >= 2:
                flat.append(float(point[0]))
                flat.append(float(point[1]))
        return flat


__all__ = ["SuryaOCRClient"]
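The quad-flattening convention is easy to check in isolation. Here is a minimal standalone sketch of the same transformation (a re-implementation of the `_flatten_polygon` logic above, not an import of `SuryaOCRClient`):

```python
from typing import Any


def flatten_polygon(polygon: Any) -> list[float]:
    """Flatten [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] corner quads into an 8-float list."""
    if not polygon:
        return []
    flat: list[float] = []
    for point in polygon:
        # Skip malformed entries; only (x, y)-like pairs contribute coords.
        if isinstance(point, (list, tuple)) and len(point) >= 2:
            flat.append(float(point[0]))
            flat.append(float(point[1]))
    return flat


quad = [[10, 20], [110, 20], [110, 45], [10, 45]]
print(flatten_polygon(quad))  # [10.0, 20.0, 110.0, 20.0, 110.0, 45.0, 10.0, 45.0]
print(flatten_polygon(None))  # []
```

Note the defensive per-point check: a partially garbled polygon degrades to fewer coordinates rather than raising.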
src/ix/pipeline/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Pipeline orchestration: Step ABC, Pipeline runner, per-step Timer.
|
||||
|
||||
Concrete steps (Setup / OCR / GenAI / Reliability / ResponseHandler) live in
|
||||
sibling modules and are wired into a :class:`Pipeline` by the application
|
||||
factory. The pipeline core is transport-agnostic: it takes a
|
||||
:class:`~ix.contracts.RequestIX`, threads a shared
|
||||
:class:`~ix.contracts.ResponseIX` through every step, and returns the
|
||||
populated response. Timings land in ``response.metadata.timings``; the first
|
||||
:class:`~ix.errors.IXException` raised aborts the pipeline and is written to
|
||||
``response.error``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.pipeline.pipeline import Pipeline, Timer
|
||||
from ix.pipeline.step import Step
|
||||
|
||||
__all__ = ["Pipeline", "Step", "Timer"]
|
||||
src/ix/pipeline/genai_step.py (new file, 216 lines)
@@ -0,0 +1,216 @@
"""GenAIStep — assemble prompt, call LLM, map provenance (spec §6.3, §7, §9.2).
|
||||
|
||||
Runs after :class:`~ix.pipeline.ocr_step.OCRStep`. Builds the chat-style
|
||||
``request_kwargs`` (messages + model name), picks the structured-output
|
||||
schema (plain ``UseCaseResponse`` or a runtime
|
||||
``ProvenanceWrappedResponse(result=..., segment_citations=...)`` when
|
||||
provenance is on), hands both to the injected :class:`GenAIClient`, and
|
||||
writes the parsed payload onto ``response_ix.ix_result``.
|
||||
|
||||
When provenance is on, the LLM-emitted ``segment_citations`` flow into
|
||||
:func:`~ix.provenance.map_segment_refs_to_provenance` to build
|
||||
``response_ix.provenance``. The per-field reliability flags
|
||||
(``provenance_verified`` / ``text_agreement``) stay ``None`` here — they
|
||||
land in :class:`~ix.pipeline.reliability_step.ReliabilityStep`.
|
||||
|
||||
Failure modes:
|
||||
|
||||
* Network / timeout / non-2xx surfaced by the client → ``IX_002_000``.
|
||||
* :class:`pydantic.ValidationError` (structured output didn't match the
|
||||
schema) → ``IX_002_001``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||
|
||||
from ix.contracts import RequestIX, ResponseIX, SegmentCitation
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.genai.client import GenAIClient
|
||||
from ix.pipeline.step import Step
|
||||
from ix.provenance import map_segment_refs_to_provenance
|
||||
from ix.segmentation import SegmentIndex
|
||||
|
||||
# Verbatim from spec §9.2 (core-pipeline spec) — inserted after the
|
||||
# use-case system prompt when provenance is on.
|
||||
_CITATION_INSTRUCTION = (
|
||||
"For each extracted field, you must also populate the `segment_citations` list.\n"
|
||||
"Each entry maps one field to the document segments that were its source.\n"
|
||||
"Set `field_path` to the dot-separated JSON path of the field "
|
||||
"(e.g. 'result.invoice_number').\n"
|
||||
"Use two separate segment ID lists:\n"
|
||||
"- `value_segment_ids`: segment IDs whose text directly contains the extracted "
|
||||
"value (e.g. ['p1_l4'] for the line containing 'INV-001').\n"
|
||||
"- `context_segment_ids`: segment IDs for surrounding label or anchor text that "
|
||||
"helped you identify the field but does not contain the value itself "
|
||||
"(e.g. ['p1_l3'] for a label like 'Invoice Number:'). Leave empty if there is "
|
||||
"no distinct label.\n"
|
||||
"Only use segment IDs that appear in the document text.\n"
|
||||
"Omit fields for which you cannot identify a source segment."
|
||||
)
|
||||
|
||||
|
||||
class GenAIStep(Step):
|
||||
"""LLM extraction + (optional) provenance mapping."""
|
||||
|
||||
def __init__(self, genai_client: GenAIClient) -> None:
|
||||
self._client = genai_client
|
||||
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
if request_ix.options.ocr.ocr_only:
|
||||
return False
|
||||
|
||||
ctx = response_ix.context
|
||||
ocr_text = (
|
||||
response_ix.ocr_result.result.text
|
||||
if response_ix.ocr_result is not None
|
||||
else None
|
||||
)
|
||||
texts = list(getattr(ctx, "texts", []) or []) if ctx is not None else []
|
||||
|
||||
if not (ocr_text and ocr_text.strip()) and not texts:
|
||||
raise IXException(IXErrorCode.IX_001_000)
|
||||
return True
|
||||
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
ctx = response_ix.context
|
||||
assert ctx is not None, "SetupStep must populate response_ix.context"
|
||||
use_case_request: Any = getattr(ctx, "use_case_request", None)
|
||||
use_case_response_cls: type[BaseModel] = getattr(ctx, "use_case_response", None)
|
||||
assert use_case_request is not None and use_case_response_cls is not None
|
||||
|
||||
opts = request_ix.options
|
||||
provenance_on = opts.provenance.include_provenance
|
||||
|
||||
# 1. System prompt — use-case default + optional citation instruction.
|
||||
system_prompt = use_case_request.system_prompt
|
||||
if provenance_on:
|
||||
system_prompt = f"{system_prompt}\n\n{_CITATION_INSTRUCTION}"
|
||||
|
||||
# 2. User text — segment-tagged when provenance is on, else plain OCR + texts.
|
||||
user_text = self._build_user_text(response_ix, provenance_on)
|
||||
|
||||
# 3. Response schema — plain or wrapped.
|
||||
response_schema = self._resolve_response_schema(
|
||||
use_case_response_cls, provenance_on
|
||||
)
|
||||
|
||||
# 4. Model selection — request override → use-case default.
|
||||
model_name = (
|
||||
opts.gen_ai.gen_ai_model_name
|
||||
or getattr(use_case_request, "default_model", None)
|
||||
)
|
||||
|
||||
request_kwargs = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_text},
|
||||
],
|
||||
}
|
||||
|
||||
# 5. Call the backend, translate errors.
|
||||
try:
|
||||
result = await self._client.invoke(
|
||||
request_kwargs=request_kwargs,
|
||||
response_schema=response_schema,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_001,
|
||||
detail=f"{use_case_response_cls.__name__}: {exc}",
|
||||
) from exc
|
||||
except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
|
||||
raise IXException(
|
||||
IXErrorCode.IX_002_000,
|
||||
detail=f"{model_name}: {exc.__class__.__name__}: {exc}",
|
||||
) from exc
|
||||
except IXException:
|
||||
raise
|
||||
|
||||
# 6. Split parsed output; write result + meta.
|
||||
if provenance_on:
|
||||
wrapped = result.parsed
|
||||
extraction: BaseModel = wrapped.result
|
||||
segment_citations: list[SegmentCitation] = list(
|
||||
getattr(wrapped, "segment_citations", []) or []
|
||||
)
|
||||
else:
|
||||
extraction = result.parsed
|
||||
segment_citations = []
|
||||
|
||||
response_ix.ix_result.result = extraction.model_dump(mode="json")
|
||||
response_ix.ix_result.meta_data = {
|
||||
"model_name": result.model_name,
|
||||
"token_usage": {
|
||||
"prompt_tokens": result.usage.prompt_tokens,
|
||||
"completion_tokens": result.usage.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
# 7. Provenance mapping — only the structural assembly. Reliability
|
||||
# flags get written in ReliabilityStep.
|
||||
if provenance_on:
|
||||
seg_idx = cast(SegmentIndex, getattr(ctx, "segment_index", None))
|
||||
if seg_idx is None:
|
||||
# No OCR was run (text-only request); skip provenance.
|
||||
response_ix.provenance = None
|
||||
else:
|
||||
response_ix.provenance = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": response_ix.ix_result.result},
|
||||
segment_citations=segment_citations,
|
||||
segment_index=seg_idx,
|
||||
max_sources_per_field=opts.provenance.max_sources_per_field,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
|
||||
return response_ix
|
||||
|
||||
def _build_user_text(self, response_ix: ResponseIX, provenance_on: bool) -> str:
|
||||
ctx = response_ix.context
|
||||
assert ctx is not None
|
||||
texts: list[str] = list(getattr(ctx, "texts", []) or [])
|
||||
seg_idx: SegmentIndex | None = getattr(ctx, "segment_index", None)
|
||||
|
||||
if provenance_on and seg_idx is not None:
|
||||
return seg_idx.to_prompt_text(context_texts=texts)
|
||||
|
||||
# Plain concat — OCR flat text + any extra paperless-style texts.
|
||||
parts: list[str] = []
|
||||
ocr_text = (
|
||||
response_ix.ocr_result.result.text
|
||||
if response_ix.ocr_result is not None
|
||||
else None
|
||||
)
|
||||
if ocr_text:
|
||||
parts.append(ocr_text)
|
||||
parts.extend(texts)
|
||||
return "\n\n".join(p for p in parts if p)
|
||||
|
||||
def _resolve_response_schema(
|
||||
self,
|
||||
use_case_response_cls: type[BaseModel],
|
||||
provenance_on: bool,
|
||||
) -> type[BaseModel]:
|
||||
if not provenance_on:
|
||||
return use_case_response_cls
|
||||
# Dynamic wrapper — one per call is fine; Pydantic caches the
|
||||
# generated JSON schema internally.
|
||||
return create_model(
|
||||
"ProvenanceWrappedResponse",
|
||||
result=(use_case_response_cls, ...),
|
||||
segment_citations=(
|
||||
list[SegmentCitation],
|
||||
Field(default_factory=list),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GenAIStep"]
|
||||
src/ix/pipeline/ocr_step.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""OCRStep — run OCR, inject page tags, build the SegmentIndex (spec §6.2).
|
||||
|
||||
Runs after :class:`~ix.pipeline.setup_step.SetupStep`. Three things
|
||||
happen:
|
||||
|
||||
1. Dispatch the flat page list to the injected :class:`OCRClient` and
|
||||
write the raw :class:`~ix.contracts.OCRResult` onto the response.
|
||||
2. Inject ``<page file="..." number="...">`` / ``</page>`` tag lines
|
||||
around each page's content so the GenAIStep can ground citations
|
||||
visually (spec §6.2c).
|
||||
3. When provenance is on, build a :class:`SegmentIndex` over the
|
||||
*non-tag* lines and stash it on the internal context.
|
||||
|
||||
Validation follows the triad-or-use_ocr rule from the spec: if any of
|
||||
``include_geometries`` / ``include_ocr_text`` / ``ocr_only`` is set but
|
||||
``context.files`` is empty, raise ``IX_000_004``. If OCR isn't wanted
|
||||
(text-only request), return ``False`` to skip the step silently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import Line, RequestIX, ResponseIX
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ocr.client import OCRClient
|
||||
from ix.pipeline.step import Step
|
||||
from ix.segmentation import PageMetadata, SegmentIndex
|
||||
|
||||
|
||||
class OCRStep(Step):
|
||||
"""Inject-and-index OCR stage."""
|
||||
|
||||
def __init__(self, ocr_client: OCRClient) -> None:
|
||||
self._client = ocr_client
|
||||
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
opts = request_ix.options.ocr
|
||||
ctx = response_ix.context
|
||||
files = list(getattr(ctx, "files", [])) if ctx is not None else []
|
||||
|
||||
ocr_artifacts_requested = (
|
||||
opts.include_geometries or opts.include_ocr_text or opts.ocr_only
|
||||
)
|
||||
if ocr_artifacts_requested and not files:
|
||||
raise IXException(IXErrorCode.IX_000_004)
|
||||
|
||||
if not files:
|
||||
return False
|
||||
|
||||
# OCR runs if use_ocr OR any of the artifact flags is set.
|
||||
return bool(opts.use_ocr or ocr_artifacts_requested)
|
||||
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
ctx = response_ix.context
|
||||
assert ctx is not None, "SetupStep must populate response_ix.context"
|
||||
|
||||
pages = list(getattr(ctx, "pages", []))
|
||||
files = list(getattr(ctx, "files", []) or [])
|
||||
page_metadata = list(getattr(ctx, "page_metadata", []) or [])
|
||||
ocr_result = await self._client.ocr(
|
||||
pages, files=files, page_metadata=page_metadata
|
||||
)
|
||||
|
||||
# Inject page tags around each OCR page's content so the LLM can
|
||||
# cross-reference the visual anchor without a separate prompt hack.
|
||||
page_metadata: list[PageMetadata] = list(
|
||||
getattr(ctx, "page_metadata", []) or []
|
||||
)
|
||||
for idx, ocr_page in enumerate(ocr_result.result.pages):
|
||||
meta = page_metadata[idx] if idx < len(page_metadata) else PageMetadata()
|
||||
file_idx = meta.file_index if meta.file_index is not None else 0
|
||||
open_tag = Line(
|
||||
text=f'<page file="{file_idx}" number="{ocr_page.page_no}">',
|
||||
bounding_box=[],
|
||||
)
|
||||
close_tag = Line(text="</page>", bounding_box=[])
|
||||
ocr_page.lines = [open_tag, *ocr_page.lines, close_tag]
|
||||
|
||||
response_ix.ocr_result = ocr_result
|
||||
|
||||
# Build SegmentIndex only when provenance is on. Segment IDs
|
||||
# deliberately skip page-tag lines (see SegmentIndex.build).
|
||||
if request_ix.options.provenance.include_provenance:
|
||||
seg_idx = SegmentIndex.build(
|
||||
ocr_result=ocr_result,
|
||||
granularity="line",
|
||||
pages_metadata=page_metadata,
|
||||
)
|
||||
ctx.segment_index = seg_idx
|
||||
|
||||
return response_ix
|
||||
|
||||
|
||||
__all__ = ["OCRStep"]
|
||||
src/ix/pipeline/pipeline.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""Pipeline runner + Timer context manager (spec §4).
|
||||
|
||||
The runner threads a fresh :class:`~ix.contracts.ResponseIX` through every
|
||||
registered :class:`Step`, records per-step elapsed seconds in
|
||||
``response.metadata.timings`` (always — even for validated-out-or-raised
|
||||
steps, so the timeline is reconstructable from logs), and aborts on the
|
||||
first :class:`~ix.errors.IXException` by writing ``response.error`` and
|
||||
stopping the loop. Non-IX exceptions propagate — the job-store layer decides
|
||||
whether to swallow or surface them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from ix.contracts import Metadata, RequestIX, ResponseIX
|
||||
from ix.errors import IXException
|
||||
from ix.pipeline.step import Step
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Context manager that appends one timing entry to a list.
|
||||
|
||||
Example::
|
||||
|
||||
timings: list[dict[str, Any]] = []
|
||||
with Timer("setup", timings):
|
||||
... # work
|
||||
# timings == [{"step": "setup", "elapsed_seconds": 0.003}]
|
||||
|
||||
The entry is appended on ``__exit__`` regardless of whether the body
|
||||
raised — the timeline stays accurate even for failed steps.
|
||||
"""
|
||||
|
||||
def __init__(self, step_name: str, sink: list[dict[str, Any]]) -> None:
|
||||
self._step_name = step_name
|
||||
self._sink = sink
|
||||
self._start: float = 0.0
|
||||
|
||||
def __enter__(self) -> Timer:
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
tb: TracebackType | None,
|
||||
) -> None:
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._sink.append({"step": self._step_name, "elapsed_seconds": elapsed})
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Runs a fixed ordered list of :class:`Step` instances against one request.
|
||||
|
||||
The pipeline is stateless — constructing once at app-startup and calling
|
||||
:meth:`start` repeatedly is the intended usage pattern. Per-request state
|
||||
lives on the :class:`~ix.contracts.ResponseIX` the pipeline creates and
|
||||
threads through every step.
|
||||
"""
|
||||
|
||||
def __init__(self, steps: list[Step]) -> None:
|
||||
self._steps = list(steps)
|
||||
|
||||
async def start(self, request_ix: RequestIX) -> ResponseIX:
|
||||
"""Execute every step; return the populated :class:`ResponseIX`.
|
||||
|
||||
Flow:
|
||||
|
||||
1. Instantiate a fresh ``ResponseIX`` seeded with request correlation
|
||||
ids.
|
||||
2. For each step: time the call, run ``validate`` then (iff True)
|
||||
``process``. Append the timing entry. If either hook raises
|
||||
:class:`~ix.errors.IXException`, write ``response.error`` and
|
||||
stop. Non-IX exceptions propagate.
|
||||
"""
|
||||
response_ix = ResponseIX(
|
||||
use_case=request_ix.use_case,
|
||||
ix_client_id=request_ix.ix_client_id,
|
||||
request_id=request_ix.request_id,
|
||||
ix_id=request_ix.ix_id,
|
||||
metadata=Metadata(),
|
||||
)
|
||||
|
||||
for step in self._steps:
|
||||
with Timer(step.step_name, response_ix.metadata.timings):
|
||||
try:
|
||||
should_run = await step.validate(request_ix, response_ix)
|
||||
except IXException as exc:
|
||||
response_ix.error = str(exc)
|
||||
return response_ix
|
||||
|
||||
if not should_run:
|
||||
continue
|
||||
|
||||
try:
|
||||
response_ix = await step.process(request_ix, response_ix)
|
||||
except IXException as exc:
|
||||
response_ix.error = str(exc)
|
||||
return response_ix
|
||||
|
||||
return response_ix
|
||||
|
||||
|
||||
__all__ = ["Pipeline", "Timer"]
|
||||
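The Timer docstring notes that a timing entry is appended even when the timed body raises; that property is the one worth demonstrating. A standalone copy of the same pattern (plain stdlib, not importing `ix.pipeline`):

```python
import time
from typing import Any


class Timer:
    """Append one {"step", "elapsed_seconds"} entry to a sink list on exit,
    whether or not the body raised."""

    def __init__(self, step_name: str, sink: list[dict[str, Any]]) -> None:
        self._step_name = step_name
        self._sink = sink
        self._start = 0.0

    def __enter__(self) -> "Timer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None (falsy) lets any exception propagate after recording.
        self._sink.append(
            {"step": self._step_name,
             "elapsed_seconds": time.perf_counter() - self._start}
        )


timings: list[dict[str, Any]] = []
with Timer("setup", timings):
    pass
try:
    with Timer("ocr", timings):
        raise RuntimeError("boom")  # timing is still recorded
except RuntimeError:
    pass
print([t["step"] for t in timings])  # ['setup', 'ocr']
```

Because `__exit__` returns a falsy value, the exception keeps propagating after the entry is appended, which is exactly what lets the pipeline runner convert an `IXException` to `response.error` while keeping the timeline complete.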
src/ix/pipeline/reliability_step.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""ReliabilityStep — writes provenance_verified + text_agreement (spec §6).
|
||||
|
||||
Runs after :class:`~ix.pipeline.genai_step.GenAIStep`. Skips entirely
|
||||
when provenance is off OR when no provenance data was built (OCR-skipped
|
||||
text-only request, for example). Otherwise delegates to
|
||||
:func:`~ix.provenance.apply_reliability_flags`, which mutates each
|
||||
:class:`~ix.contracts.FieldProvenance` in place and fills the two
|
||||
summary counters (``verified_fields``, ``text_agreement_fields``) on
|
||||
``quality_metrics``.
|
||||
|
||||
No own dispatch logic — everything interesting lives in the normalisers
|
||||
+ verifier modules and is unit-tested there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ix.contracts import RequestIX, ResponseIX
|
||||
from ix.pipeline.step import Step
|
||||
from ix.provenance import apply_reliability_flags
|
||||
|
||||
|
||||
class ReliabilityStep(Step):
|
||||
"""Fills per-field reliability flags on ``response.provenance``."""
|
||||
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
if not request_ix.options.provenance.include_provenance:
|
||||
return False
|
||||
return response_ix.provenance is not None
|
||||
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
assert response_ix.provenance is not None # validate() guarantees
|
||||
|
||||
ctx = response_ix.context
|
||||
texts: list[str] = (
|
||||
list(getattr(ctx, "texts", []) or []) if ctx is not None else []
|
||||
)
|
||||
use_case_response_cls = cast(
|
||||
type[BaseModel],
|
||||
getattr(ctx, "use_case_response", None) if ctx is not None else None,
|
||||
)
|
||||
|
||||
apply_reliability_flags(
|
||||
provenance_data=response_ix.provenance,
|
||||
use_case_response=use_case_response_cls,
|
||||
texts=texts,
|
||||
)
|
||||
return response_ix
|
||||
|
||||
|
||||
__all__ = ["ReliabilityStep"]
|
||||
src/ix/pipeline/response_handler_step.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""ResponseHandlerStep — final shape-up before the caller sees the payload (spec §8).
|
||||
|
||||
Does three purely mechanical things:
|
||||
|
||||
1. When ``include_ocr_text`` is set, concatenate every non-tag line text
|
||||
into ``ocr_result.result.text`` (pages joined with blank line).
|
||||
2. When ``include_geometries`` is **not** set (the default), strip
|
||||
``ocr_result.result.pages`` and ``ocr_result.meta_data`` — geometries
|
||||
are heavyweight; callers opt in.
|
||||
3. Clear ``response_ix.context`` (belt-and-braces — ``Field(exclude=True)``
|
||||
already keeps it out of ``model_dump`` output).
|
||||
|
||||
:meth:`validate` always returns True per spec.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ix.contracts import RequestIX, ResponseIX
|
||||
from ix.pipeline.step import Step
|
||||
|
||||
_PAGE_TAG_RE = re.compile(r"^\s*<\s*/?\s*page\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _is_page_tag(text: str | None) -> bool:
|
||||
if not text:
|
||||
return False
|
||||
return bool(_PAGE_TAG_RE.match(text))
|
||||
|
||||
|
||||
class ResponseHandlerStep(Step):
|
||||
"""Final shape-up step."""
|
||||
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
return True
|
||||
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
ocr_opts = request_ix.options.ocr
|
||||
|
||||
# 1. Attach flat OCR text if requested.
|
||||
if ocr_opts.include_ocr_text:
|
||||
page_texts: list[str] = []
|
||||
for page in response_ix.ocr_result.result.pages:
|
||||
line_texts = [
|
||||
line.text or ""
|
||||
for line in page.lines
|
||||
if not _is_page_tag(line.text)
|
||||
]
|
||||
page_texts.append("\n".join(line_texts))
|
||||
response_ix.ocr_result.result.text = "\n\n".join(page_texts) or None
|
||||
|
||||
# 2. Strip geometries unless explicitly retained.
|
||||
if not ocr_opts.include_geometries:
|
||||
response_ix.ocr_result.result.pages = []
|
||||
response_ix.ocr_result.meta_data = {}
|
||||
|
||||
# 3. Drop the internal context — already Field(exclude=True),
|
||||
# this is defense in depth.
|
||||
response_ix.context = None
|
||||
|
||||
return response_ix
|
||||
|
||||
|
||||
__all__ = ["ResponseHandlerStep"]
|
||||
src/ix/pipeline/setup_step.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""SetupStep — request validation + file ingestion (spec §6.1).
|
||||
|
||||
Runs first in the pipeline. Responsibilities:
|
||||
|
||||
* Reject nonsense requests up front (``IX_000_000`` / ``IX_000_002``).
|
||||
* Normalise ``Context.files`` entries: plain ``str`` → ``FileRef(url=s, headers={})``.
|
||||
* Download every file in parallel (``asyncio.gather``) via the injected
|
||||
``fetcher`` callable (default: :func:`ix.ingestion.fetch.fetch_file`).
|
||||
* Byte-sniff MIMEs; gate unsupported ones via ``IX_000_005``.
|
||||
* Load the use case pair from :data:`ix.use_cases.REGISTRY`
|
||||
(``IX_001_001`` on miss).
|
||||
* Hand the fetched files + raw texts to the injected ``ingestor`` so
|
||||
``context.pages`` + ``context.page_metadata`` are ready for the OCRStep.
|
||||
|
||||
Dependency-inject the fetcher/ingestor/mime-detector so unit tests stay
|
hermetic — production wires the defaults via the app factory.
"""

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any, Protocol

from ix.contracts import FileRef, RequestIX, ResponseIX
from ix.contracts.response import _InternalContext
from ix.errors import IXErrorCode, IXException
from ix.ingestion import (
    DocumentIngestor,
    FetchConfig,
    detect_mime,
    fetch_file,
    require_supported,
)
from ix.pipeline.step import Step
from ix.use_cases import get_use_case
from ix.use_cases.inline import build_use_case_classes


class _Fetcher(Protocol):
    """Async callable that downloads one file to ``tmp_dir``."""

    async def __call__(
        self, file_ref: FileRef, tmp_dir: Path, cfg: FetchConfig
    ) -> Path: ...


class _Ingestor(Protocol):
    """Turns fetched files + raw texts into a flat Page list."""

    def build_pages(
        self,
        files: list[tuple[Path, str]],
        texts: list[str],
    ) -> tuple[list[Any], list[Any]]: ...


class _MimeDetector(Protocol):
    """Returns the canonical MIME string for a local file."""

    def __call__(self, path: Path) -> str: ...


class SetupStep(Step):
    """First pipeline step: validate + fetch + MIME + use-case load + pages."""

    def __init__(
        self,
        fetcher: _Fetcher | None = None,
        ingestor: _Ingestor | None = None,
        tmp_dir: Path | None = None,
        fetch_config: FetchConfig | None = None,
        mime_detector: _MimeDetector | None = None,
    ) -> None:
        self._fetcher: _Fetcher = fetcher or fetch_file
        self._ingestor: _Ingestor = ingestor or DocumentIngestor()
        self._tmp_dir = tmp_dir or Path("/tmp/ix")
        self._fetch_config = fetch_config
        self._mime_detector: _MimeDetector = mime_detector or detect_mime

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        if request_ix is None:  # pragma: no cover - runtime sanity; Pydantic rejects it earlier
            raise IXException(IXErrorCode.IX_000_000)
        ctx = request_ix.context
        if not ctx.files and not ctx.texts:
            raise IXException(IXErrorCode.IX_000_002)
        return True

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        # 1. Load the use-case pair — either from the caller's inline
        #    definition (wins over registry) or from the registry by name.
        #    Done early so an unknown name / bad inline definition fails
        #    before we waste time downloading files.
        if request_ix.use_case_inline is not None:
            use_case_request_cls, use_case_response_cls = build_use_case_classes(
                request_ix.use_case_inline
            )
        else:
            use_case_request_cls, use_case_response_cls = get_use_case(
                request_ix.use_case
            )
        use_case_request = use_case_request_cls()

        # 2. Resolve the per-request scratch directory. ix_id is assigned
        #    by the transport adapter — fall back to request_id for the MVP
        #    unit tests that don't set ix_id.
        sub = request_ix.ix_id or request_ix.request_id
        request_tmp = self._tmp_dir / sub
        request_tmp.mkdir(parents=True, exist_ok=True)

        # 3. Normalise file entries to FileRef. Plain strings get empty headers.
        normalised: list[FileRef] = [
            FileRef(url=entry, headers={}) if isinstance(entry, str) else entry
            for entry in request_ix.context.files
        ]

        # 4. Download in parallel. No retry — IX_000_007 propagates.
        fetch_cfg = self._resolve_fetch_config()
        local_paths: list[Path] = []
        if normalised:
            local_paths = list(
                await asyncio.gather(
                    *(
                        self._fetcher(file_ref, request_tmp, fetch_cfg)
                        for file_ref in normalised
                    )
                )
            )

        # 5. MIME-sniff each download; reject unsupported.
        files_with_mime: list[tuple[Path, str]] = []
        for path in local_paths:
            mime = self._mime_detector(path)
            require_supported(mime)
            files_with_mime.append((path, mime))

        # 6. Build the flat Page list + per-page metadata.
        pages, page_metadata = self._ingestor.build_pages(
            files=files_with_mime,
            texts=list(request_ix.context.texts),
        )

        # 7. Stash everything on the internal context for downstream steps.
        response_ix.context = _InternalContext(
            pages=pages,
            files=files_with_mime,
            texts=list(request_ix.context.texts),
            use_case_request=use_case_request,
            use_case_response=use_case_response_cls,
            segment_index=None,
            page_metadata=page_metadata,
            tmp_dir=request_tmp,
        )

        # 8. Echo use-case display name onto the public response.
        response_ix.use_case_name = getattr(use_case_request, "use_case_name", None)

        return response_ix

    def _resolve_fetch_config(self) -> FetchConfig:
        if self._fetch_config is not None:
            return self._fetch_config
        # Fallback defaults — used only when the caller didn't inject one.
        # Real values land via ix.config in Chunk 3.
        return FetchConfig(
            connect_timeout_s=10.0,
            read_timeout_s=30.0,
            max_bytes=50 * 1024 * 1024,
        )


__all__ = ["SetupStep"]
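Step 4 above fans the downloads out with one coroutine per file and `asyncio.gather`. A minimal standalone sketch of that pattern — `fake_fetch` is a stand-in for illustration only, not ix's `fetch_file`:

```python
# Sketch of the fan-out pattern from step 4: one coroutine per file,
# gathered in parallel. gather preserves input order, so the resulting
# paths line up with the request's file list.
import asyncio

async def fake_fetch(url: str) -> str:
    await asyncio.sleep(0)  # stands in for the real network I/O
    return f"/tmp/ix/{url.rsplit('/', 1)[-1]}"

async def main() -> list[str]:
    urls = ["https://example.com/a.pdf", "https://example.com/b.pdf"]
    return list(await asyncio.gather(*(fake_fetch(u) for u in urls)))

print(asyncio.run(main()))  # ['/tmp/ix/a.pdf', '/tmp/ix/b.pdf']
```

Because `gather` raises the first exception it sees, a single failed download (IX_000_007) aborts the whole batch — matching the "no retry" comment in the step.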
58  src/ix/pipeline/step.py  Normal file
@@ -0,0 +1,58 @@
"""Step ABC — the pipeline-step contract (spec §3).
|
||||
|
||||
Every pipeline step implements two async hooks:
|
||||
|
||||
* :meth:`Step.validate` — returns ``True`` when the step should run for this
|
||||
request, ``False`` when it should be silently skipped. May raise
|
||||
:class:`~ix.errors.IXException` to abort the pipeline with an error code.
|
||||
* :meth:`Step.process` — does the work; mutates the shared ``response_ix``
|
||||
and returns it. May raise :class:`~ix.errors.IXException` to abort.
|
||||
|
||||
Both hooks are async so steps that need I/O (file download, OCR, LLM) can
|
||||
cooperate with the asyncio event loop without sync-async conversion dances.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ix.contracts import RequestIX, ResponseIX
|
||||
|
||||
|
||||
class Step(ABC):
|
||||
"""Abstract base for every pipeline step.
|
||||
|
||||
Subclasses override both hooks. The pipeline runner guarantees
|
||||
``validate`` is called before ``process`` for a given step, and that
|
||||
``process`` runs iff ``validate`` returned ``True``.
|
||||
|
||||
The :attr:`step_name` property controls the label written to
|
||||
``metadata.timings``. Defaults to the class name so production steps
|
||||
(``SetupStep``, ``OCRStep``, …) log under their own name; test doubles
|
||||
override it with the value under test.
|
||||
"""
|
||||
|
||||
@property
|
||||
def step_name(self) -> str:
|
||||
"""Label used in ``metadata.timings``. Default: class name."""
|
||||
return type(self).__name__
|
||||
|
||||
@abstractmethod
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
"""Return ``True`` to run :meth:`process`, ``False`` to skip silently.
|
||||
|
||||
Raise :class:`~ix.errors.IXException` to abort the pipeline with an
|
||||
error code on ``response_ix.error``.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
"""Run the step; return the (same, mutated) ``response_ix``.
|
||||
|
||||
Raising :class:`~ix.errors.IXException` aborts the pipeline.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ["Step"]
|
||||
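The validate-then-process contract can be sketched standalone. The `Step` ABC here is re-declared locally (same shape as the one in `step.py`) so the snippet runs without the ix package; `UpperStep` and the tiny `run` helper are hypothetical illustrations of the runner guarantee, not ix code:

```python
# Minimal sketch of the Step contract: the runner calls validate first and
# runs process iff validate returned True.
import asyncio
from abc import ABC, abstractmethod

class Step(ABC):
    @property
    def step_name(self) -> str:
        return type(self).__name__

    @abstractmethod
    async def validate(self, request, response) -> bool: ...

    @abstractmethod
    async def process(self, request, response): ...

class UpperStep(Step):
    async def validate(self, request, response) -> bool:
        return bool(request.get("text"))  # skip silently when there is no text

    async def process(self, request, response):
        response["text"] = request["text"].upper()
        return response

async def run(step: Step, request: dict, response: dict) -> dict:
    # the runner guarantee: process runs iff validate returned True
    if await step.validate(request, response):
        response = await step.process(request, response)
    return response

print(asyncio.run(run(UpperStep(), {"text": "hi"}, {})))  # {'text': 'HI'}
```

With an empty request the step is skipped and the response passes through untouched — no error, no timing entry.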
41  src/ix/provenance/__init__.py  Normal file
@@ -0,0 +1,41 @@
"""Provenance subsystem — normalisers, mapper, verifier.
|
||||
|
||||
Three pieces compose the reliability check:
|
||||
|
||||
* :mod:`ix.provenance.normalize` — pure text/number/date/IBAN normalisers
|
||||
used to compare OCR snippets to extracted values.
|
||||
* :mod:`ix.provenance.mapper` — resolves LLM-emitted segment IDs to
|
||||
:class:`~ix.contracts.provenance.FieldProvenance` entries.
|
||||
* :mod:`ix.provenance.verify` — per-field-type dispatcher that writes the
|
||||
``provenance_verified`` / ``text_agreement`` flags.
|
||||
|
||||
Only :mod:`normalize` is exported from the package at this step; the mapper
|
||||
and verifier land in task 1.8.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.provenance.mapper import (
|
||||
map_segment_refs_to_provenance,
|
||||
resolve_nested_path,
|
||||
)
|
||||
from ix.provenance.normalize import (
|
||||
normalize_date,
|
||||
normalize_iban,
|
||||
normalize_number,
|
||||
normalize_string,
|
||||
should_skip_text_agreement,
|
||||
)
|
||||
from ix.provenance.verify import apply_reliability_flags, verify_field
|
||||
|
||||
__all__ = [
|
||||
"apply_reliability_flags",
|
||||
"map_segment_refs_to_provenance",
|
||||
"normalize_date",
|
||||
"normalize_iban",
|
||||
"normalize_number",
|
||||
"normalize_string",
|
||||
"resolve_nested_path",
|
||||
"should_skip_text_agreement",
|
||||
"verify_field",
|
||||
]
|
||||
145  src/ix/provenance/mapper.py  Normal file
@@ -0,0 +1,145 @@
"""Maps LLM-emitted :class:`SegmentCitation` lists to :class:`ProvenanceData`.
|
||||
|
||||
Implements spec §9.4. The algorithm is deliberately small:
|
||||
|
||||
1. For each citation, pick the seg-id list (``value`` vs. ``value_and_context``).
|
||||
2. Cap at ``max_sources_per_field``.
|
||||
3. Resolve each ID via :meth:`SegmentIndex.lookup_segment`; count misses.
|
||||
4. Resolve the field's value by dot-path traversal of the extraction result.
|
||||
5. Build a :class:`FieldProvenance`. Skip fields that resolved to zero sources.
|
||||
|
||||
No verification / normalisation happens here — this module's sole job is
|
||||
structural assembly. :mod:`ix.provenance.verify` does the reliability pass
|
||||
downstream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
|
||||
from ix.contracts.provenance import (
|
||||
ExtractionSource,
|
||||
FieldProvenance,
|
||||
ProvenanceData,
|
||||
SegmentCitation,
|
||||
)
|
||||
from ix.segmentation import SegmentIndex
|
||||
|
||||
SourceType = Literal["value", "value_and_context"]
|
||||
|
||||
|
||||
_BRACKET_RE = re.compile(r"\[(\d+)\]")
|
||||
|
||||
|
||||
def resolve_nested_path(data: Any, path: str) -> Any:
|
||||
"""Resolve a dot-path into ``data`` with ``[N]`` array notation normalised.
|
||||
|
||||
``"result.items[0].name"`` → walks ``data["result"]["items"][0]["name"]``.
|
||||
Returns ``None`` at any missing-key / index-out-of-range step so callers
|
||||
can fall back to recording the field with a null value.
|
||||
"""
|
||||
normalised = _BRACKET_RE.sub(r".\1", path)
|
||||
cur: Any = data
|
||||
for part in normalised.split("."):
|
||||
if cur is None:
|
||||
return None
|
||||
if part.isdigit() and isinstance(cur, list):
|
||||
i = int(part)
|
||||
if i < 0 or i >= len(cur):
|
||||
return None
|
||||
cur = cur[i]
|
||||
elif isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
else:
|
||||
return None
|
||||
return cur
|
||||
|
||||
|
||||
def _segment_ids_for_citation(
|
||||
citation: SegmentCitation,
|
||||
source_type: SourceType,
|
||||
) -> list[str]:
|
||||
if source_type == "value":
|
||||
return list(citation.value_segment_ids)
|
||||
# value_and_context
|
||||
return list(citation.value_segment_ids) + list(citation.context_segment_ids)
|
||||
|
||||
|
||||
def map_segment_refs_to_provenance(
|
||||
extraction_result: dict[str, Any],
|
||||
segment_citations: list[SegmentCitation],
|
||||
segment_index: SegmentIndex,
|
||||
max_sources_per_field: int,
|
||||
min_confidence: float, # reserved (no-op for MVP)
|
||||
include_bounding_boxes: bool,
|
||||
source_type: SourceType,
|
||||
) -> ProvenanceData:
|
||||
"""Build a :class:`ProvenanceData` from LLM citations and a SegmentIndex."""
|
||||
# min_confidence is reserved for future use (see spec §2 provenance options).
|
||||
_ = min_confidence
|
||||
|
||||
fields: dict[str, FieldProvenance] = {}
|
||||
invalid_references = 0
|
||||
|
||||
for citation in segment_citations:
|
||||
seg_ids = _segment_ids_for_citation(citation, source_type)[:max_sources_per_field]
|
||||
sources: list[ExtractionSource] = []
|
||||
for seg_id in seg_ids:
|
||||
pos = segment_index.lookup_segment(seg_id)
|
||||
if pos is None:
|
||||
invalid_references += 1
|
||||
continue
|
||||
sources.append(
|
||||
ExtractionSource(
|
||||
page_number=pos["page"],
|
||||
file_index=pos.get("file_index"),
|
||||
bounding_box=pos["bbox"] if include_bounding_boxes else None,
|
||||
text_snippet=pos["text"],
|
||||
relevance_score=1.0,
|
||||
segment_id=seg_id,
|
||||
)
|
||||
)
|
||||
if not sources:
|
||||
continue
|
||||
|
||||
value = resolve_nested_path(extraction_result, citation.field_path)
|
||||
fields[citation.field_path] = FieldProvenance(
|
||||
field_name=citation.field_path.split(".")[-1],
|
||||
field_path=citation.field_path,
|
||||
value=value,
|
||||
sources=sources,
|
||||
confidence=None,
|
||||
)
|
||||
|
||||
total_fields_in_result = _count_leaf_fields(extraction_result.get("result", {}))
|
||||
coverage_rate: float | None = None
|
||||
if total_fields_in_result > 0:
|
||||
coverage_rate = len(fields) / total_fields_in_result
|
||||
|
||||
return ProvenanceData(
|
||||
fields=fields,
|
||||
quality_metrics={
|
||||
"fields_with_provenance": len(fields),
|
||||
"total_fields": total_fields_in_result or None,
|
||||
"coverage_rate": coverage_rate,
|
||||
"invalid_references": invalid_references,
|
||||
},
|
||||
segment_count=len(segment_index._ordered_ids),
|
||||
granularity=segment_index.granularity,
|
||||
)
|
||||
|
||||
|
||||
def _count_leaf_fields(data: Any) -> int:
|
||||
"""Count non-container leaves (str/int/float/Decimal/date/bool/None) recursively."""
|
||||
if data is None:
|
||||
return 1
|
||||
if isinstance(data, dict):
|
||||
if not data:
|
||||
return 0
|
||||
return sum(_count_leaf_fields(v) for v in data.values())
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return 0
|
||||
return sum(_count_leaf_fields(v) for v in data)
|
||||
return 1
|
||||
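The dot-path walk (step 4 of the mapper's algorithm) is easy to exercise in isolation. A standalone sketch with the same traversal rules — `resolve` and `doc` are re-declared here only so the snippet runs without the ix package:

```python
# Same algorithm as resolve_nested_path: normalise "[N]" to ".N", then walk
# dicts by key and lists by index, returning None on any miss.
import re

_BRACKET_RE = re.compile(r"\[(\d+)\]")

def resolve(data, path: str):
    cur = data
    for part in _BRACKET_RE.sub(r".\1", path).split("."):
        if cur is None:
            return None
        if part.isdigit() and isinstance(cur, list):
            i = int(part)
            if not 0 <= i < len(cur):
                return None
            cur = cur[i]
        elif isinstance(cur, dict):
            cur = cur.get(part)
        else:
            return None
    return cur

doc = {"result": {"items": [{"name": "invoice"}]}}
print(resolve(doc, "result.items[0].name"))  # invoice
print(resolve(doc, "result.items[3].name"))  # None (index out of range)
```

Returning `None` instead of raising is deliberate: a bad LLM citation should degrade to a null-valued provenance entry, never abort the request.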
181  src/ix/provenance/normalize.py  Normal file
@@ -0,0 +1,181 @@
"""Pure normalisers used by the reliability check (spec §6).
|
||||
|
||||
The ReliabilityStep compares extracted values against OCR segment snippets
|
||||
(and raw ``context.texts``) after passing both sides through the same
|
||||
normaliser. Keeping these functions pure (no IO, no state) means the
|
||||
ReliabilityStep itself can stay a thin dispatcher and every rule is
|
||||
directly unit-testable.
|
||||
|
||||
All normalisers return ``str`` so the downstream ``substring`` / ``equals``
|
||||
comparison is trivial.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, get_origin
|
||||
|
||||
from dateutil import parser as _dateparser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# String
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strip punctuation that rarely carries semantics in extracted vs. OCR compare:
|
||||
# colon, comma, dot, semicolon, paren/bracket, slash, exclamation, question.
|
||||
_PUNCTUATION_RE = re.compile(r"[.,:;!?()\[\]{}/\\'\"`]")
|
||||
_WHITESPACE_RE = re.compile(r"\s+")
|
||||
|
||||
|
||||
def normalize_string(s: str) -> str:
|
||||
"""NFKC + casefold + punctuation strip + whitespace collapse."""
|
||||
s = unicodedata.normalize("NFKC", s)
|
||||
s = s.casefold()
|
||||
s = _PUNCTUATION_RE.sub(" ", s)
|
||||
s = _WHITESPACE_RE.sub(" ", s).strip()
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Number
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strip currency symbols / codes and everything that isn't a digit, sign,
|
||||
# apostrophe, dot, or comma. The apostrophe/dot/comma handling is done in a
|
||||
# second pass that figures out thousands-separator vs. decimal-separator from
|
||||
# structure.
|
||||
_NUMERIC_KEEP_RE = re.compile(r"[^0-9.,'\s\-+]")
|
||||
|
||||
|
||||
def _parse_numeric_string(raw: str) -> Decimal:
|
||||
"""Heuristically decode localised numbers.
|
||||
|
||||
Rules:
|
||||
|
||||
* Strip anything that isn't a digit, sign, dot, comma, apostrophe, or
|
||||
whitespace (this drops currency symbols / codes).
|
||||
* Apostrophes are always thousands separators (Swiss-German style).
|
||||
* Whitespace is always a thousands separator (fr-FR style).
|
||||
* If both ``.`` and ``,`` appear, the rightmost is the decimal separator
|
||||
and the other is the thousands separator.
|
||||
* If only one of them appears: assume it's the decimal separator when it
|
||||
has exactly 2 trailing digits, otherwise a thousands separator.
|
||||
"""
|
||||
cleaned = _NUMERIC_KEEP_RE.sub("", raw).strip()
|
||||
cleaned = cleaned.replace("'", "").replace(" ", "")
|
||||
|
||||
has_dot = "." in cleaned
|
||||
has_comma = "," in cleaned
|
||||
|
||||
if has_dot and has_comma:
|
||||
if cleaned.rfind(".") > cleaned.rfind(","):
|
||||
# dot is decimal
|
||||
cleaned = cleaned.replace(",", "")
|
||||
else:
|
||||
# comma is decimal
|
||||
cleaned = cleaned.replace(".", "").replace(",", ".")
|
||||
elif has_comma:
|
||||
# Only comma — treat as decimal if 2 digits follow, else thousands.
|
||||
tail = cleaned.split(",")[-1]
|
||||
if len(tail) == 2 and tail.isdigit():
|
||||
cleaned = cleaned.replace(",", ".")
|
||||
else:
|
||||
cleaned = cleaned.replace(",", "")
|
||||
elif has_dot:
|
||||
# Only dot — same heuristic in reverse. If multiple dots appear they
|
||||
# must be thousands separators (e.g. "1.234.567"); strip them. A
|
||||
# single dot with a non-2-digit tail stays as-is (1.5 is 1.5).
|
||||
tail = cleaned.split(".")[-1]
|
||||
if (len(tail) != 2 or not tail.isdigit()) and cleaned.count(".") > 1:
|
||||
cleaned = cleaned.replace(".", "")
|
||||
if cleaned in ("", "+", "-"):
|
||||
raise InvalidOperation(f"cannot parse number: {raw!r}")
|
||||
return Decimal(cleaned)
|
||||
|
||||
|
||||
def normalize_number(value: int | float | Decimal | str) -> str:
|
||||
"""Return ``"[-]DDD.DD"`` canonical form — always 2 decimal places.
|
||||
|
||||
Accepts localized strings (``"CHF 1'234.56"``, ``"1.234,56 EUR"``,
|
||||
``"-123.45"``) as well as native numeric types.
|
||||
"""
|
||||
if isinstance(value, Decimal):
|
||||
dec = value
|
||||
elif isinstance(value, (int, float)):
|
||||
dec = Decimal(str(value))
|
||||
else:
|
||||
dec = _parse_numeric_string(value)
|
||||
# Quantize to 2dp; keep sign.
|
||||
quantized = dec.quantize(Decimal("0.01"))
|
||||
return format(quantized, "f")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Date
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_date(value: date | datetime | str) -> str:
|
||||
"""Parse via dateutil (dayfirst=True) and return ISO ``YYYY-MM-DD``."""
|
||||
if isinstance(value, datetime):
|
||||
return value.date().isoformat()
|
||||
if isinstance(value, date):
|
||||
return value.isoformat()
|
||||
parsed = _dateparser.parse(value, dayfirst=True)
|
||||
return parsed.date().isoformat()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IBAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_iban(s: str) -> str:
|
||||
"""Upper-case + strip all whitespace. No format validation (call site's job)."""
|
||||
return "".join(s.split()).upper()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Short-value skip rule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def should_skip_text_agreement(value: Any, field_type: Any) -> bool:
|
||||
"""Return True when ``text_agreement`` should be recorded as ``None``.
|
||||
|
||||
Rules (spec §6 ReliabilityStep):
|
||||
|
||||
1. ``value is None`` → skip.
|
||||
2. ``field_type`` is a ``Literal[...]`` → skip (enum labels don't appear
|
||||
verbatim in the source text).
|
||||
3. Stringified value length ≤ 2 chars → skip (short strings collide with
|
||||
random OCR noise).
|
||||
4. Numeric value (int/float/Decimal) with ``|v| < 10`` → skip.
|
||||
|
||||
``provenance_verified`` still runs for all of these — the bbox-anchored
|
||||
cite is stronger than a global text scan for short values.
|
||||
"""
|
||||
if value is None:
|
||||
return True
|
||||
|
||||
# Literal check — Python 3.12 returns `typing.Literal` from get_origin.
|
||||
import typing
|
||||
|
||||
if get_origin(field_type) is typing.Literal:
|
||||
return True
|
||||
|
||||
# Numeric short-value rule — check before the stringified-length rule so
|
||||
# that "10" (len 2) is still considered on the numeric side. Booleans
|
||||
# are a subtype of int; we exclude them so they fall through to the
|
||||
# string rule ("True" has len 4 so it doesn't trip anyway).
|
||||
if not isinstance(value, bool) and isinstance(value, (int, float, Decimal)):
|
||||
try:
|
||||
return abs(Decimal(str(value))) < 10
|
||||
except InvalidOperation:
|
||||
pass
|
||||
|
||||
# Stringified length rule (strings and anything not numeric).
|
||||
return len(str(value)) <= 2
|
||||
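The separator heuristic is the subtle part of the number normaliser. A compact standalone sketch of the same rules (re-declared here so it runs without the ix package or dateutil):

```python
# Same decimal-vs-thousands heuristic as _parse_numeric_string: strip
# currency noise, drop apostrophes/spaces, then disambiguate dot and comma
# by position and tail length.
import re
from decimal import Decimal

_KEEP = re.compile(r"[^0-9.,'\s\-+]")

def parse_localised(raw: str) -> Decimal:
    s = _KEEP.sub("", raw).strip().replace("'", "").replace(" ", "")
    if "." in s and "," in s:
        if s.rfind(".") > s.rfind(","):
            s = s.replace(",", "")                    # dot is the decimal separator
        else:
            s = s.replace(".", "").replace(",", ".")  # comma is the decimal separator
    elif "," in s:
        tail = s.split(",")[-1]
        s = s.replace(",", ".") if len(tail) == 2 and tail.isdigit() else s.replace(",", "")
    elif "." in s:
        tail = s.split(".")[-1]
        if (len(tail) != 2 or not tail.isdigit()) and s.count(".") > 1:
            s = s.replace(".", "")                    # multiple dots → thousands separators
    return Decimal(s)

print(parse_localised("CHF 1'234.56"))  # 1234.56
print(parse_localised("1.234,56 EUR"))  # 1234.56
print(parse_localised("1.234.567"))     # 1234567
```

Note the deliberate ambiguity left in place: a lone separator with a two-digit tail is always read as a decimal point, so `"1,234"` parses as one thousand two hundred thirty-four but `"1,23"` as 1.23.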
231  src/ix/provenance/verify.py  Normal file
@@ -0,0 +1,231 @@
"""Reliability verifier — writes `provenance_verified` + `text_agreement`.
|
||||
|
||||
Implements the dispatch table from spec §6 ReliabilityStep. The normalisers
|
||||
in :mod:`ix.provenance.normalize` do the actual string/number/date work; this
|
||||
module chooses which one to run for each field based on its Pydantic
|
||||
annotation, and writes the two reliability flags onto the existing
|
||||
:class:`FieldProvenance` records in place.
|
||||
|
||||
The summary counters (``verified_fields``, ``text_agreement_fields``) land in
|
||||
``provenance.quality_metrics`` alongside the existing coverage metrics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import types
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ix.contracts.provenance import FieldProvenance, ProvenanceData
|
||||
from ix.provenance.normalize import (
|
||||
normalize_date,
|
||||
normalize_iban,
|
||||
normalize_number,
|
||||
normalize_string,
|
||||
should_skip_text_agreement,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap_optional(tp: Any) -> Any:
|
||||
"""Return the non-None arm of ``Optional[X]`` / ``X | None``.
|
||||
|
||||
Handles both ``typing.Union`` and ``types.UnionType`` (PEP 604 ``X | Y``
|
||||
unions), which ``get_type_hints`` returns on Python 3.12+.
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if origin is Union or origin is types.UnionType:
|
||||
args = [a for a in get_args(tp) if a is not type(None)]
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return tp
|
||||
|
||||
|
||||
def _is_literal(tp: Any) -> bool:
|
||||
return get_origin(tp) is Literal
|
||||
|
||||
|
||||
def _is_iban_field(field_path: str) -> bool:
|
||||
return "iban" in field_path.lower()
|
||||
|
||||
|
||||
def _compare_string(value: str, snippet: str) -> bool:
|
||||
try:
|
||||
return normalize_string(value) in normalize_string(snippet)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_NUMBER_TOKEN_RE = re.compile(r"[+\-]?[\d][\d',.\s]*\d|[+\-]?\d+")
|
||||
|
||||
|
||||
def _compare_number(value: Any, snippet: str) -> bool:
|
||||
try:
|
||||
canonical = normalize_number(value)
|
||||
except (InvalidOperation, ValueError):
|
||||
return False
|
||||
# Try the whole snippet first (cheap path when the snippet IS the number).
|
||||
try:
|
||||
if normalize_number(snippet) == canonical:
|
||||
return True
|
||||
except (InvalidOperation, ValueError):
|
||||
pass
|
||||
# Fall back to scanning numeric substrings — OCR snippets commonly carry
|
||||
# labels ("Closing balance CHF 1'234.56") that confuse a whole-string
|
||||
# numeric parse.
|
||||
for match in _NUMBER_TOKEN_RE.finditer(snippet):
|
||||
token = match.group()
|
||||
try:
|
||||
if normalize_number(token) == canonical:
|
||||
return True
|
||||
except (InvalidOperation, ValueError):
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def _compare_date(value: Any, snippet: str) -> bool:
|
||||
try:
|
||||
iso_value = normalize_date(value)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
# Find any date-like chunk in snippet; try normalising each token segment.
|
||||
# Simplest heuristic: try snippet as a whole; on failure, scan tokens.
|
||||
try:
|
||||
if normalize_date(snippet) == iso_value:
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
# Token scan — dateutil will raise on the non-date tokens, which is fine.
|
||||
for token in _tokenise_for_date(snippet):
|
||||
try:
|
||||
if normalize_date(token) == iso_value:
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def _tokenise_for_date(s: str) -> list[str]:
|
||||
"""Split on whitespace + common punctuation so date strings survive whole.
|
||||
|
||||
Keeps dots / slashes / dashes inside tokens (they're valid date
|
||||
separators); splits on spaces, commas, semicolons, colons, brackets.
|
||||
"""
|
||||
return [t for t in re.split(r"[\s,;:()\[\]]+", s) if t]
|
||||
|
||||
|
||||
def _compare_iban(value: str, snippet: str) -> bool:
|
||||
try:
|
||||
return normalize_iban(value) in normalize_iban(snippet)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _compare_for_type(value: Any, field_type: Any, snippet: str, field_path: str) -> bool:
|
||||
unwrapped = _unwrap_optional(field_type)
|
||||
|
||||
# Date / datetime.
|
||||
if unwrapped in (date, datetime):
|
||||
return _compare_date(value, snippet)
|
||||
|
||||
# Numeric.
|
||||
if unwrapped in (int, float, Decimal):
|
||||
return _compare_number(value, snippet)
|
||||
|
||||
# IBAN (detected by field name).
|
||||
if _is_iban_field(field_path):
|
||||
return _compare_iban(str(value), snippet)
|
||||
|
||||
# Default: string substring.
|
||||
return _compare_string(str(value), snippet)
|
||||
|
||||
|
||||
def verify_field(
|
||||
field_provenance: FieldProvenance,
|
||||
field_type: Any,
|
||||
texts: list[str],
|
||||
) -> tuple[bool | None, bool | None]:
|
||||
"""Compute the two reliability flags for one field.
|
||||
|
||||
Returns ``(provenance_verified, text_agreement)``. See spec §6 for the
|
||||
dispatch rules. ``None`` on either slot means the check was skipped
|
||||
(Literal, None value, or short-value for text_agreement).
|
||||
"""
|
||||
value = field_provenance.value
|
||||
unwrapped = _unwrap_optional(field_type)
|
||||
|
||||
# Skip Literal / None value entirely — both flags None.
|
||||
if _is_literal(unwrapped) or value is None:
|
||||
return (None, None)
|
||||
|
||||
# provenance_verified: scan cited segments.
|
||||
provenance_verified: bool | None
|
||||
if not field_provenance.sources:
|
||||
provenance_verified = False
|
||||
else:
|
||||
provenance_verified = any(
|
||||
_compare_for_type(value, field_type, s.text_snippet, field_provenance.field_path)
|
||||
for s in field_provenance.sources
|
||||
)
|
||||
|
||||
# text_agreement: None if no texts, else apply short-value rule.
|
||||
text_agreement: bool | None
|
||||
if not texts or should_skip_text_agreement(value, field_type):
|
||||
text_agreement = None
|
||||
else:
|
||||
concatenated = "\n".join(texts)
|
||||
text_agreement = _compare_for_type(
|
||||
value, field_type, concatenated, field_provenance.field_path
|
||||
)
|
||||
|
||||
return provenance_verified, text_agreement
|
||||
|
||||
|
||||
def apply_reliability_flags(
|
||||
provenance_data: ProvenanceData,
|
||||
use_case_response: type[BaseModel],
|
||||
texts: list[str],
|
||||
) -> None:
|
||||
"""Apply :func:`verify_field` to every field in ``provenance_data``.
|
||||
|
||||
Mutates ``provenance_data`` in place:
|
||||
|
||||
* Each ``FieldProvenance``'s ``provenance_verified`` and
|
||||
``text_agreement`` slots are filled.
|
||||
* ``quality_metrics['verified_fields']`` is set to the number of fields
|
||||
whose ``provenance_verified`` is True.
|
||||
* ``quality_metrics['text_agreement_fields']`` is set likewise for
|
||||
``text_agreement``.
|
||||
|
||||
``use_case_response`` is the Pydantic class for the extraction schema
|
||||
(e.g. :class:`~ix.use_cases.bank_statement_header.BankStatementHeader`).
|
||||
Type hints are resolved via ``get_type_hints`` so forward-refs and
|
||||
``str | None`` unions are normalised consistently.
|
||||
"""
|
||||
type_hints = get_type_hints(use_case_response)
|
||||
|
||||
verified_count = 0
|
||||
text_agreement_count = 0
|
||||
for fp in provenance_data.fields.values():
|
||||
# Field path is something like "result.bank_name" — the part after
|
||||
# the first dot is the attribute name on the response schema.
|
||||
leaf = fp.field_path.split(".", 1)[-1]
|
||||
# For nested shapes we only resolve the top-level name; MVP use cases
|
||||
# are flat so that's enough. When we ship nested schemas we'll walk
|
||||
# the annotation tree here.
|
||||
top_attr = leaf.split(".")[0]
|
||||
field_type: Any = type_hints.get(top_attr, str)
|
||||
|
||||
pv, ta = verify_field(fp, field_type, texts)
|
||||
fp.provenance_verified = pv
|
||||
fp.text_agreement = ta
|
||||
if pv is True:
|
||||
verified_count += 1
|
||||
if ta is True:
|
||||
text_agreement_count += 1
|
||||
|
||||
provenance_data.quality_metrics["verified_fields"] = verified_count
|
||||
provenance_data.quality_metrics["text_agreement_fields"] = text_agreement_count
|
||||
7  src/ix/segmentation/__init__.py  Normal file
@@ -0,0 +1,7 @@
"""Segment-index module: maps short IDs (``p1_l0``) to on-page anchors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.segmentation.segment_index import PageMetadata, SegmentIndex
|
||||
|
||||
__all__ = ["PageMetadata", "SegmentIndex"]
|
||||
140  src/ix/segmentation/segment_index.py  Normal file
@@ -0,0 +1,140 @@
"""SegmentIndex — maps short segment IDs (`p1_l0`) to their on-page anchors.
|
||||
|
||||
Per spec §9.1. Built from an :class:`OCRResult` + a parallel list of
|
||||
:class:`PageMetadata` entries carrying the 0-based ``file_index`` for each
|
||||
flat-list page. Used by two places:
|
||||
|
||||
1. :class:`GenAIStep` calls :meth:`SegmentIndex.to_prompt_text` to feed the
|
||||
LLM a segment-tagged user message.
|
||||
2. :func:`ix.provenance.mapper.map_segment_refs_to_provenance` calls
|
||||
:meth:`SegmentIndex.lookup_segment` to resolve each LLM-cited segment ID
back to its page + bbox + text snippet.

Page-tag lines (``<page …>`` / ``</page>``) emitted by the OCR step for
visual grounding are explicitly *excluded* from IDs so the LLM can never
cite them as provenance.
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any

from ix.contracts.provenance import BoundingBox
from ix.contracts.response import OCRResult

_PAGE_TAG_RE = re.compile(r"^\s*<\s*/?\s*page\b", re.IGNORECASE)


@dataclass(slots=True)
class PageMetadata:
    """Per-page metadata travelling alongside an :class:`OCRResult`.

    The OCR engine doesn't know which input file a page came from — that
    mapping lives in the pipeline's internal context. Passing it in
    explicitly keeps the segmentation module decoupled from the ingestion
    module.
    """

    file_index: int | None = None


def _is_page_tag(text: str | None) -> bool:
    if not text:
        return False
    return bool(_PAGE_TAG_RE.match(text))


def _normalize_bbox(coords: list[float], width: float, height: float) -> BoundingBox:
    """8-coord polygon → 0-1 normalised. Zero dimensions → pass through."""
    if width <= 0 or height <= 0 or len(coords) != 8:
        return BoundingBox(coordinates=list(coords))
    normalised = [
        coords[0] / width,
        coords[1] / height,
        coords[2] / width,
        coords[3] / height,
        coords[4] / width,
        coords[5] / height,
        coords[6] / width,
        coords[7] / height,
    ]
    return BoundingBox(coordinates=normalised)


@dataclass(slots=True)
class SegmentIndex:
    """Read-mostly segment lookup built once per request.

    ``_id_to_position`` is a flat dict for O(1) lookup; ``_ordered_ids``
    preserves insertion order so ``to_prompt_text`` produces a stable
    deterministic rendering (matters for prompt caching).
    """

    granularity: str = "line"
    _id_to_position: dict[str, dict[str, Any]] = field(default_factory=dict)
    _ordered_ids: list[str] = field(default_factory=list)

    @classmethod
    def build(
        cls,
        ocr_result: OCRResult,
        *,
        granularity: str = "line",
        pages_metadata: list[PageMetadata],
    ) -> SegmentIndex:
        """Construct a new index from an OCR result + per-page metadata.

        ``pages_metadata`` must have the same length as
        ``ocr_result.result.pages`` and in the same order. The pipeline
        builds both in :class:`SetupStep` / :class:`OCRStep` before calling
        here.
        """
        idx = cls(granularity=granularity)
        pages = ocr_result.result.pages
        for global_pos, page in enumerate(pages, start=1):
            meta = (
                pages_metadata[global_pos - 1]
                if global_pos - 1 < len(pages_metadata)
                else PageMetadata()
            )
            line_idx_in_page = 0
            for line in page.lines:
                if _is_page_tag(line.text):
                    continue
                seg_id = f"p{global_pos}_l{line_idx_in_page}"
                bbox = _normalize_bbox(
                    list(line.bounding_box),
                    float(page.width),
                    float(page.height),
                )
                idx._id_to_position[seg_id] = {
                    "page": global_pos,
                    "bbox": bbox,
                    "text": line.text or "",
                    "file_index": meta.file_index,
                }
                idx._ordered_ids.append(seg_id)
                line_idx_in_page += 1
        return idx

    def lookup_segment(self, segment_id: str) -> dict[str, Any] | None:
        """Return the stored position dict, or ``None`` if the ID is unknown."""
        return self._id_to_position.get(segment_id)

    def to_prompt_text(self, context_texts: list[str] | None = None) -> str:
        """Emit the LLM-facing user text.

        Format (deterministic):

        - One ``[pN_lM] text`` line per indexed segment, in insertion order.
        - If ``context_texts`` is non-empty: a blank line, then each entry on
          its own line, untagged. This mirrors spec §7.2b — pre-OCR text
          (Paperless, etc.) is added as untagged context for the model to
          cross-reference against the tagged document body.
        """
        lines: list[str] = [
            f"[{seg_id}] {self._id_to_position[seg_id]['text']}"
            for seg_id in self._ordered_ids
        ]
        out = "\n".join(lines)
        if context_texts:
            suffix = "\n\n".join(context_texts)
            out = f"{out}\n\n{suffix}" if out else suffix
        return out

src/ix/store/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@

"""Async Postgres job store — SQLAlchemy 2.0 ORM + repo.

Exports are intentionally minimal: the engine factory and the declarative
``Base`` + ``IxJob`` ORM. The ``jobs_repo`` module lives next to this one and
exposes the CRUD methods callers actually need; we don't re-export from the
package so it stays obvious where each function lives.
"""

from __future__ import annotations

from ix.store.engine import get_engine, get_session_factory, reset_engine
from ix.store.models import Base, IxJob

__all__ = [
    "Base",
    "IxJob",
    "get_engine",
    "get_session_factory",
    "reset_engine",
]

src/ix/store/engine.py (new file, 76 lines)
@@ -0,0 +1,76 @@

"""Lazy async engine + session-factory singletons.

The factories read ``IX_POSTGRES_URL`` from the environment on first call. In
Task 3.2 this switches to ``get_config()``; for now we go through ``os.environ``
directly so the store module doesn't depend on config that doesn't exist yet.

Both factories are idempotent on success — repeat calls return the same
engine / sessionmaker. ``reset_engine`` nukes the cache and should only be
used in tests (where we tear down and recreate the DB between sessions).
"""

from __future__ import annotations

import os

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def _resolve_url() -> str:
    """Grab the Postgres URL from the environment.

    Task 3.2 refactors this to go through ``ix.config.get_config()``; this
    version keeps the store module usable during the bootstrap window where
    ``ix.config`` doesn't exist yet. Behaviour after the refactor is
    identical — both paths ultimately read ``IX_POSTGRES_URL``.
    """

    try:
        from ix.config import get_config
    except ImportError:
        url = os.environ.get("IX_POSTGRES_URL")
        if not url:
            raise RuntimeError(
                "IX_POSTGRES_URL is not set and ix.config is unavailable"
            ) from None
        return url
    return get_config().postgres_url


def get_engine() -> AsyncEngine:
    """Return the process-wide async engine; create on first call."""

    global _engine
    if _engine is None:
        _engine = create_async_engine(_resolve_url(), pool_pre_ping=True)
    return _engine


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the process-wide session factory; create on first call.

    ``expire_on_commit=False`` so ORM instances stay usable after ``commit()``
    — we frequently commit inside a repo method and then ``model_validate``
    the row outside the session.
    """

    global _session_factory
    if _session_factory is None:
        _session_factory = async_sessionmaker(get_engine(), expire_on_commit=False)
    return _session_factory


def reset_engine() -> None:
    """Drop the cached engine + session factory. Test-only."""

    global _engine, _session_factory
    _engine = None
    _session_factory = None
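The lazy-singleton-plus-reset pattern used by `get_engine` / `reset_engine` is easy to demonstrate in isolation. A minimal sketch (names here are hypothetical; a dict stands in for the module-level globals):

```python
_cache: dict[str, object] = {}


def get_resource() -> object:
    """Create the shared resource on first call; return the same one after."""
    if "engine" not in _cache:
        _cache["engine"] = object()  # stands in for create_async_engine(...)
    return _cache["engine"]


def reset_resource() -> None:
    """Drop the cached instance — test-only, mirrors reset_engine()."""
    _cache.pop("engine", None)


a = get_resource()
b = get_resource()
assert a is b        # idempotent on success: repeat calls share the instance
reset_resource()
assert get_resource() is not a  # fresh instance after a test teardown
```

The reset hook exists precisely because tests tear down and recreate the database between sessions; production code never calls it.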

src/ix/store/jobs_repo.py (new file, 414 lines)
@@ -0,0 +1,414 @@

"""Async CRUD over ``ix_jobs`` — the one module the worker / REST touches.

Every method takes an :class:`AsyncSession` (caller-owned transaction). The
caller commits. We don't manage transactions inside repo methods because the
worker sometimes needs to claim + run-pipeline + mark-done inside one
long-running unit of work, and an inside-the-method commit would break that.

A few invariants worth stating up front:

* ``ix_id`` is a 16-char hex string assigned by :func:`insert_pending` on
  first insert when the incoming ``RequestIX`` doesn't already carry one;
  if it does, that value is reused.
* ``(client_id, request_id)`` is unique — on collision we return the
  existing row unchanged. Callback URLs on the second insert are ignored;
  the first insert's metadata wins.
* Claim uses ``FOR UPDATE SKIP LOCKED`` so concurrent workers never pick the
  same row, and a session holding a lock doesn't block a sibling claimer.
* Status transitions: ``pending → running → (done | error)``. The sweeper is
  the only path back to ``pending`` (and only from ``running``); terminal
  states are stable.
"""

from __future__ import annotations

import secrets
from collections.abc import Iterable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Literal
from uuid import UUID, uuid4

from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ix.contracts.job import Job
from ix.contracts.request import RequestIX
from ix.contracts.response import ResponseIX
from ix.store.models import IxJob

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession


def _new_ix_id() -> str:
    """Transport-assigned 16-hex handle.

    ``secrets.token_hex(8)`` gives 16 hex characters (64 bits of entropy);
    good enough to tag logs per spec §3 without collision risk across the
    lifetime of the service.
    """

    return secrets.token_hex(8)


def _orm_to_job(row: IxJob) -> Job:
    """Round-trip ORM row back through the Pydantic ``Job`` contract.

    The JSONB columns come out as plain dicts; we let Pydantic re-validate
    them into :class:`RequestIX` / :class:`ResponseIX`. Catching validation
    errors here would mask real bugs; we let them surface.
    """

    return Job(
        job_id=row.job_id,
        ix_id=row.ix_id,
        client_id=row.client_id,
        request_id=row.request_id,
        status=row.status,  # type: ignore[arg-type]
        request=RequestIX.model_validate(row.request),
        response=(
            ResponseIX.model_validate(row.response) if row.response is not None else None
        ),
        callback_url=row.callback_url,
        callback_status=row.callback_status,  # type: ignore[arg-type]
        attempts=row.attempts,
        created_at=row.created_at,
        started_at=row.started_at,
        finished_at=row.finished_at,
    )


async def insert_pending(
    session: AsyncSession,
    request: RequestIX,
    callback_url: str | None,
) -> Job:
    """Insert a pending row; return the new or existing :class:`Job`.

    Uses ``INSERT ... ON CONFLICT DO NOTHING`` on the
    ``(client_id, request_id)`` unique index, then re-selects. If the insert
    was a no-op the existing row is returned verbatim (status / callback_url
    unchanged) — callers rely on this for idempotent resubmission.
    """

    ix_id = request.ix_id or _new_ix_id()
    job_id = uuid4()

    # Serialise the request through Pydantic so JSONB gets plain JSON types,
    # not datetime / Decimal instances asyncpg would reject.
    request_json = request.model_copy(update={"ix_id": ix_id}).model_dump(
        mode="json"
    )

    stmt = (
        pg_insert(IxJob)
        .values(
            job_id=job_id,
            ix_id=ix_id,
            client_id=request.ix_client_id,
            request_id=request.request_id,
            status="pending",
            request=request_json,
            response=None,
            callback_url=callback_url,
            callback_status=None,
            attempts=0,
        )
        .on_conflict_do_nothing(index_elements=["client_id", "request_id"])
    )
    await session.execute(stmt)

    row = await session.scalar(
        select(IxJob).where(
            IxJob.client_id == request.ix_client_id,
            IxJob.request_id == request.request_id,
        )
    )
    assert row is not None, "insert_pending: row missing after upsert"
    return _orm_to_job(row)


async def claim_next_pending(session: AsyncSession) -> Job | None:
    """Atomically pick the oldest pending row and flip it to running.

    ``FOR UPDATE SKIP LOCKED`` means a sibling worker can never deadlock on
    our row; they'll skip past it and grab the next pending entry. The
    sibling test in :mod:`tests/integration/test_jobs_repo` asserts this.
    """

    stmt = (
        select(IxJob)
        .where(IxJob.status == "pending")
        .order_by(IxJob.created_at)
        .limit(1)
        .with_for_update(skip_locked=True)
    )
    row = await session.scalar(stmt)
    if row is None:
        return None

    row.status = "running"
    row.started_at = datetime.now(UTC)
    await session.flush()
    return _orm_to_job(row)


async def get(session: AsyncSession, job_id: UUID) -> Job | None:
    row = await session.scalar(select(IxJob).where(IxJob.job_id == job_id))
    return _orm_to_job(row) if row is not None else None


async def queue_position(
    session: AsyncSession, job_id: UUID
) -> tuple[int, int]:
    """Return ``(ahead, total_active)`` for a pending/running job.

    ``ahead`` counts active jobs (``pending`` or ``running``) that would be
    claimed by the worker before this one:

    * any ``running`` job is always ahead — it has the worker already.
    * other ``pending`` jobs with a strictly older ``created_at`` are ahead
      (the worker picks pending rows in ``ORDER BY created_at`` per
      :func:`claim_next_pending`).

    ``total_active`` is the total count of ``pending`` + ``running`` rows.

    Terminal jobs (``done`` / ``error``) always return ``(0, 0)`` — there is
    no meaningful "position" for a finished job.
    """

    row = await session.scalar(select(IxJob).where(IxJob.job_id == job_id))
    if row is None:
        return (0, 0)
    if row.status not in ("pending", "running"):
        return (0, 0)

    total_active = int(
        await session.scalar(
            select(func.count())
            .select_from(IxJob)
            .where(IxJob.status.in_(("pending", "running")))
        )
        or 0
    )

    if row.status == "running":
        # A running row is at the head of the queue for our purposes.
        return (0, total_active)

    # Pending: count running rows (always ahead) + older pending rows.
    # We tiebreak on ``job_id`` for deterministic ordering when multiple
    # rows share a ``created_at`` (e.g. the same transaction inserts two
    # jobs, which Postgres stamps with identical ``now()`` values).
    running_ahead = int(
        await session.scalar(
            select(func.count())
            .select_from(IxJob)
            .where(IxJob.status == "running")
        )
        or 0
    )
    pending_ahead = int(
        await session.scalar(
            select(func.count())
            .select_from(IxJob)
            .where(
                IxJob.status == "pending",
                (
                    (IxJob.created_at < row.created_at)
                    | (
                        (IxJob.created_at == row.created_at)
                        & (IxJob.job_id < row.job_id)
                    )
                ),
            )
        )
        or 0
    )
    return (running_ahead + pending_ahead, total_active)
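The counting rules in `queue_position` (running rows always ahead; pending rows ahead when strictly older, with a job_id tiebreak on equal timestamps) are pure arithmetic and can be checked without a database. An in-memory sketch under the same rules (dicts stand in for ORM rows):

```python
from datetime import datetime, timedelta


def queue_position(rows: list[dict], job_id: int) -> tuple[int, int]:
    """(ahead, total_active) — the repo's counting rules over plain dicts."""
    active = [r for r in rows if r["status"] in ("pending", "running")]
    me = next((r for r in rows if r["id"] == job_id), None)
    if me is None or me["status"] not in ("pending", "running"):
        return (0, 0)  # missing or terminal: no meaningful position
    if me["status"] == "running":
        return (0, len(active))  # running row is at the head of the queue
    running_ahead = sum(1 for r in active if r["status"] == "running")
    pending_ahead = sum(
        1
        for r in active
        if r["status"] == "pending"
        # tuple compare = strictly older, or same timestamp + smaller id
        and (r["created_at"], r["id"]) < (me["created_at"], me["id"])
    )
    return (running_ahead + pending_ahead, len(active))


t0 = datetime(2024, 1, 1)
rows = [
    {"id": 1, "status": "running", "created_at": t0},
    {"id": 2, "status": "pending", "created_at": t0 + timedelta(seconds=1)},
    {"id": 3, "status": "pending", "created_at": t0 + timedelta(seconds=2)},
    {"id": 4, "status": "done",    "created_at": t0},
]
print(queue_position(rows, 3))  # → (2, 3): the running job + one older pending
```

The tuple comparison `(created_at, id)` encodes exactly the SQL predicate `created_at < x OR (created_at = x AND job_id < y)` used by the real query.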


async def get_by_correlation(
    session: AsyncSession, client_id: str, request_id: str
) -> Job | None:
    row = await session.scalar(
        select(IxJob).where(
            IxJob.client_id == client_id,
            IxJob.request_id == request_id,
        )
    )
    return _orm_to_job(row) if row is not None else None


async def mark_done(
    session: AsyncSession, job_id: UUID, response: ResponseIX
) -> None:
    """Write the pipeline's response and move to a terminal state.

    Status is ``done`` iff ``response.error is None``; any non-None error
    flips us to ``error``. Spec §3 lifecycle invariant.
    """

    status = "done" if response.error is None else "error"
    await session.execute(
        update(IxJob)
        .where(IxJob.job_id == job_id)
        .values(
            status=status,
            response=response.model_dump(mode="json"),
            finished_at=datetime.now(UTC),
        )
    )


async def mark_error(
    session: AsyncSession, job_id: UUID, response: ResponseIX
) -> None:
    """Convenience wrapper that always writes ``status='error'``.

    Separate from :func:`mark_done` for readability at call sites: when the
    worker knows it caught an exception the pipeline didn't handle itself,
    ``mark_error`` signals intent even if the response body happens to have
    a populated error field.
    """

    await session.execute(
        update(IxJob)
        .where(IxJob.job_id == job_id)
        .values(
            status="error",
            response=response.model_dump(mode="json"),
            finished_at=datetime.now(UTC),
        )
    )


async def update_callback_status(
    session: AsyncSession,
    job_id: UUID,
    status: Literal["delivered", "failed"],
) -> None:
    await session.execute(
        update(IxJob)
        .where(IxJob.job_id == job_id)
        .values(callback_status=status)
    )


async def sweep_orphans(
    session: AsyncSession,
    now: datetime,
    max_running_seconds: int,
) -> list[UUID]:
    """Reset stale ``running`` rows back to ``pending`` and bump ``attempts``.

    Called once at worker startup (spec §3) to rescue jobs whose owner died
    mid-pipeline. The threshold is time-based on ``started_at`` so a still-
    running worker never reclaims its own in-flight job — callers pass
    ``2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS`` per spec.
    """

    # Pick candidates and return their ids so the worker can log what it
    # did. Two-step (SELECT then UPDATE) is clearer than RETURNING for
    # callers who want the id list alongside a plain UPDATE.
    candidates = (
        await session.scalars(
            select(IxJob.job_id).where(
                IxJob.status == "running",
                IxJob.started_at < now - _as_interval(max_running_seconds),
            )
        )
    ).all()
    if not candidates:
        return []

    await session.execute(
        update(IxJob)
        .where(IxJob.job_id.in_(candidates))
        .values(
            status="pending",
            started_at=None,
            attempts=IxJob.attempts + 1,
        )
    )
    return list(candidates)


_LIST_RECENT_LIMIT_CAP = 200


async def list_recent(
    session: AsyncSession,
    *,
    limit: int = 50,
    offset: int = 0,
    status: str | Iterable[str] | None = None,
    client_id: str | None = None,
) -> tuple[list[Job], int]:
    """Return a page of recent jobs, newest first, plus total matching count.

    Powers the ``/ui/jobs`` listing page. Ordering is ``created_at DESC``.
    ``total`` reflects matching rows *before* limit/offset so the template
    can render "showing N of M".

    Parameters
    ----------
    limit:
        Maximum rows to return. Capped at
        :data:`_LIST_RECENT_LIMIT_CAP` (200) to bound the JSON payload
        size — callers that pass a larger value get clamped silently.
    offset:
        Non-negative row offset. Negative values raise ``ValueError``
        because the template treats offset as a page cursor; a negative
        cursor is a bug at the call site, not something to paper over.
    status:
        If set, restrict to the given status(es). Accepts a single
        :data:`Job.status` value or any iterable (list/tuple/set). Values
        outside the lifecycle enum simply match nothing — we don't try
        to validate here; the DB CHECK constraint already bounds the set.
    client_id:
        If set, exact match on :attr:`IxJob.client_id`. No substring /
        prefix match — simple and predictable.
    """

    if offset < 0:
        raise ValueError(f"offset must be >= 0, got {offset}")
    effective_limit = max(0, min(limit, _LIST_RECENT_LIMIT_CAP))

    filters = []
    if status is not None:
        if isinstance(status, str):
            filters.append(IxJob.status == status)
        else:
            # Any iterable works, including an empty one — an empty
            # IN-list matches no rows, which is exactly what "no valid
            # statuses" should mean.
            filters.append(IxJob.status.in_(list(status)))
    if client_id is not None:
        filters.append(IxJob.client_id == client_id)

    total_q = select(func.count()).select_from(IxJob)
    list_q = select(IxJob).order_by(IxJob.created_at.desc())
    for f in filters:
        total_q = total_q.where(f)
        list_q = list_q.where(f)

    total = int(await session.scalar(total_q) or 0)
    rows = (
        await session.scalars(list_q.limit(effective_limit).offset(offset))
    ).all()
    return [_orm_to_job(r) for r in rows], total


def _as_interval(seconds: int):  # type: ignore[no-untyped-def]
    """Return a SQL interval expression for ``seconds``.

    We build the interval via ``func.make_interval`` so asyncpg doesn't have
    to guess at a text-form cast — the server-side
    ``make_interval(years, months, weeks, days, hours, mins, secs)`` call is
    unambiguous and avoids locale-dependent parsing.
    """

    return func.make_interval(0, 0, 0, 0, 0, 0, seconds)

src/ix/store/models.py (new file, 86 lines)
@@ -0,0 +1,86 @@

"""SQLAlchemy 2.0 ORM for ``ix_jobs``.

Shape matches the initial migration (``alembic/versions/001_initial_ix_jobs.py``),
which in turn matches spec §4. JSONB columns carry the RequestIX / ResponseIX
Pydantic payloads; we don't wrap them in custom TypeDecorators — the repo does
an explicit ``model_dump(mode="json")`` on write and ``model_validate`` on read
so the ORM stays a thin mapping layer and the Pydantic round-trip logic stays
colocated with the other contract code.

The status column is a plain string — the CHECK constraint in the DB enforces
the allowed values. Using a SQLAlchemy ``Enum`` type here would double-bind
the enum values on both sides and force a migration each time we add a state.
"""

from __future__ import annotations

from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import CheckConstraint, DateTime, Index, Integer, Text, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import UUID as PgUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Shared declarative base for the store package."""


class IxJob(Base):
    """ORM mapping for the ``ix_jobs`` table.

    One row per submitted extraction job. Lifecycle: pending → running →
    (done | error). The worker is the only writer that flips status past
    pending; the REST / pg_queue adapters only insert.
    """

    __tablename__ = "ix_jobs"
    __table_args__ = (
        CheckConstraint(
            "status IN ('pending', 'running', 'done', 'error')",
            name="ix_jobs_status_check",
        ),
        CheckConstraint(
            "callback_status IS NULL OR callback_status IN "
            "('pending', 'delivered', 'failed')",
            name="ix_jobs_callback_status_check",
        ),
        Index(
            "ix_jobs_status_created",
            "status",
            "created_at",
            postgresql_where=text("status = 'pending'"),
        ),
        Index(
            "ix_jobs_client_request",
            "client_id",
            "request_id",
            unique=True,
        ),
    )

    job_id: Mapped[UUID] = mapped_column(PgUUID(as_uuid=True), primary_key=True)
    ix_id: Mapped[str] = mapped_column(Text, nullable=False)
    client_id: Mapped[str] = mapped_column(Text, nullable=False)
    request_id: Mapped[str] = mapped_column(Text, nullable=False)
    status: Mapped[str] = mapped_column(Text, nullable=False)
    request: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
    response: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
    callback_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    callback_status: Mapped[str | None] = mapped_column(Text, nullable=True)
    attempts: Mapped[int] = mapped_column(
        Integer, nullable=False, server_default=text("0")
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=text("now()"),
    )
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    finished_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
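The "plain string column, DB-side CHECK constraint" choice is easy to see in action. The sketch below uses stdlib `sqlite3` as a stand-in for Postgres (the table name is illustrative); the point is that the database, not an ORM enum type, rejects out-of-range statuses:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE ix_jobs_demo ("
    "  job_id TEXT PRIMARY KEY,"
    "  status TEXT NOT NULL"
    "    CHECK (status IN ('pending', 'running', 'done', 'error'))"
    ")"
)
con.execute("INSERT INTO ix_jobs_demo VALUES ('a', 'pending')")  # accepted

try:
    con.execute("INSERT INTO ix_jobs_demo VALUES ('b', 'bogus')")
    rejected = False
except sqlite3.IntegrityError:
    rejected = True  # the CHECK constraint fires, no migration needed client-side

print("rejected:", rejected)  # → rejected: True
```

Adding a new lifecycle state would mean altering one constraint in a migration, while the Python side keeps treating status as a string.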

src/ix/ui/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@

"""Minimal browser UI served alongside the REST API at ``/ui``.

The module is intentionally thin: templates + HTMX + Pico CSS (all from
CDNs, no build step). Uploads land in ``{cfg.tmp_dir}/ui/<uuid>.pdf`` and
are submitted through the same :func:`ix.store.jobs_repo.insert_pending`
entry point the REST adapter uses — the UI does not duplicate that logic.
"""

from __future__ import annotations

from ix.ui.routes import build_router

__all__ = ["build_router"]

src/ix/ui/routes.py (new file, 568 lines)
@@ -0,0 +1,568 @@

"""``/ui`` router — thin HTML wrapper over the existing jobs pipeline.

Design notes:

* Uploads stream to ``{cfg.tmp_dir}/ui/{uuid4()}.pdf`` via aiofiles; the
  file persists for the lifetime of the ``ix_id`` (no cleanup cron — spec
  deferred).
* The submission handler builds a :class:`RequestIX` (inline use case
  supported) and inserts it via the same
  :func:`ix.store.jobs_repo.insert_pending` the REST adapter uses.
* Responses are HTML. For HTMX-triggered submissions the handler returns
  ``HX-Redirect`` so the whole page swaps; for plain form posts it returns
  a 303 redirect.
* The fragment endpoint powers the polling loop: while the job is
  pending/running, the fragment auto-refreshes every 2 s via
  ``hx-trigger="every 2s"``; when terminal, the trigger is dropped and the
  pretty-printed response is rendered with highlight.js.
* A process-wide 60-second cache of the OCR GPU flag (read from the
  injected :class:`Probes`) gates a "Surya is running on CPU" notice on
  the fragment. The fragment is polled every 2 s; re-probing the OCR
  client on every poll is waste — one probe per minute is plenty.
"""

from __future__ import annotations

import json
import time
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import Annotated
from urllib.parse import unquote, urlencode, urlsplit
from uuid import UUID

import aiofiles
from fastapi import (
    APIRouter,
    Depends,
    File,
    Form,
    HTTPException,
    Query,
    Request,
    UploadFile,
)
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from ix.adapters.rest.routes import Probes, get_probes, get_session_factory_dep
from ix.config import AppConfig, get_config
from ix.contracts.request import (
    Context,
    FileRef,
    GenAIOptions,
    InlineUseCase,
    OCROptions,
    Options,
    ProvenanceOptions,
    RequestIX,
    UseCaseFieldDef,
)
from ix.store import jobs_repo
from ix.use_cases import REGISTRY

TEMPLATES_DIR = Path(__file__).parent / "templates"
STATIC_DIR = Path(__file__).parent / "static"

# Module-level cache for the OCR GPU flag. The tuple is ``(value, expires_at)``
# where ``expires_at`` is a monotonic-clock deadline. A per-request call to
# :func:`_cached_ocr_gpu` re-probes only once the deadline has passed.
_OCR_GPU_CACHE: tuple[bool | None, float] = (None, 0.0)
_OCR_GPU_TTL_SECONDS = 60.0


def _templates() -> Jinja2Templates:
    """Build a fresh Jinja env; cheap enough to do per DI call."""

    return Jinja2Templates(directory=str(TEMPLATES_DIR))


def _ui_tmp_dir(cfg: AppConfig) -> Path:
    """Where uploads land. Created on first use; never cleaned up."""

    d = Path(cfg.tmp_dir) / "ui"
    d.mkdir(parents=True, exist_ok=True)
    return d


def _cached_ocr_gpu(probes: Probes) -> bool | None:
    """Read the cached OCR GPU flag, re-probing if the TTL has elapsed.

    Used by the index + fragment routes so the HTMX poll loop doesn't hit
    the OCR client's torch-probe every 2 seconds. Falls back to ``None``
    (unknown) on any probe error.
    """

    global _OCR_GPU_CACHE
    value, expires_at = _OCR_GPU_CACHE
    now = time.monotonic()
    if now >= expires_at:
        try:
            value = probes.ocr_gpu()
        except Exception:
            value = None
        _OCR_GPU_CACHE = (value, now + _OCR_GPU_TTL_SECONDS)
    return value
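The TTL cache in `_cached_ocr_gpu` is a generic pattern: a `(value, deadline)` pair keyed on the monotonic clock, with a swallow-and-cache-None fallback on probe failure. A standalone sketch (names are illustrative; `probe` is any zero-arg callable):

```python
import time

_CACHE: tuple[object, float] = (None, 0.0)  # (value, expires_at)
TTL_SECONDS = 60.0


def cached_probe(probe) -> object:
    """Re-run `probe` only once the monotonic deadline has passed."""
    global _CACHE
    value, expires_at = _CACHE
    now = time.monotonic()
    if now >= expires_at:
        try:
            value = probe()
        except Exception:
            value = None  # unknown beats crashing the render path
        _CACHE = (value, now + TTL_SECONDS)
    return value


calls: list[int] = []
first = cached_probe(lambda: calls.append(1) or True)   # probe runs → True
second = cached_probe(lambda: calls.append(1) or False) # within TTL → cached True
print(first, second, len(calls))  # → True True 1
```

Using `time.monotonic()` instead of `time.time()` makes the deadline immune to wall-clock adjustments, which matters for a cache that lives as long as the process.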


_VALID_STATUSES = ("pending", "running", "done", "error")
_JOBS_LIST_DEFAULT_LIMIT = 50
_JOBS_LIST_MAX_LIMIT = 200


def _use_case_label(request: RequestIX | None) -> str:
    """Prefer the inline use-case label, fall back to the registered name."""

    if request is None:
        return "—"
    if request.use_case_inline is not None:
        return request.use_case_inline.use_case_name or request.use_case
    return request.use_case or "—"


def _row_elapsed_seconds(job) -> int | None:  # type: ignore[no-untyped-def]
    """Wall-clock seconds for a terminal row (finished - started).

    Used in the list view's "Elapsed" column. Returns ``None`` for rows
    that haven't finished yet (pending / running / missing timestamps) so
    the template can render ``—`` instead.
    """

    if job.status in ("done", "error") and job.started_at and job.finished_at:
        return max(0, int((job.finished_at - job.started_at).total_seconds()))
    return None


def _humanize_delta(seconds: int) -> str:
    """Coarse-grained "N min ago" for the list view.

    The list renders many rows; we don't need second-accuracy here. For
    sub-minute values we say "just now" to avoid a jumpy display.
    """

    if seconds < 60:
        return "just now"
    mins = seconds // 60
    if mins < 60:
        return f"{mins} min ago"
    hours = mins // 60
    if hours < 24:
        return f"{hours} h ago"
    days = hours // 24
    return f"{days} d ago"


def _fmt_elapsed_seconds(seconds: int | None) -> str:
    if seconds is None:
        return "—"
    return f"{seconds // 60:02d}:{seconds % 60:02d}"


def _file_display_entries(
    request: RequestIX | None,
) -> list[str]:
    """Human-readable filename(s) for a request's ``context.files``.

    Prefers :attr:`FileRef.display_name`. Falls back to the URL's basename
    (``unquote``-ed so ``%20`` → space). Plain string entries use the same
    basename rule. Empty list for a request with no files.
    """

    if request is None:
        return []
    out: list[str] = []
    for entry in request.context.files:
        if isinstance(entry, FileRef):
            if entry.display_name:
                out.append(entry.display_name)
                continue
            url = entry.url
        else:
            url = entry
        basename = unquote(urlsplit(url).path.rsplit("/", 1)[-1]) or url
        out.append(basename)
    return out
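The basename fallback used for display names is a two-step stdlib combination: split the URL, take the last path component, percent-decode it. A standalone sketch of the same rule (a dict stands in for the `FileRef` contract object; `display_name` is the hypothetical helper here):

```python
from urllib.parse import unquote, urlsplit


def display_name(entry) -> str:
    """display_name if present, else the URL's unquoted basename."""
    if isinstance(entry, dict) and entry.get("display_name"):
        return entry["display_name"]
    url = entry["url"] if isinstance(entry, dict) else entry
    # urlsplit drops the query/fragment; rsplit takes the last path segment;
    # unquote turns %20 back into a space. Fall back to the raw URL if the
    # path has no basename at all (e.g. trailing slash).
    return unquote(urlsplit(url).path.rsplit("/", 1)[-1]) or url


print(display_name("https://x.test/docs/My%20Invoice.pdf"))            # → My Invoice.pdf
print(display_name({"url": "https://x.test/a.pdf",
                    "display_name": "Quote Q1"}))                      # → Quote Q1
```

Splitting before unquoting matters: decoding first could turn an encoded `/` (`%2F`) inside a filename into a path separator and truncate the name.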
|
||||

def build_router() -> APIRouter:
    """Return a fresh router. Kept as a factory so :mod:`ix.app` can wire DI."""

    router = APIRouter(prefix="/ui", tags=["ui"])

    @router.get("", response_class=HTMLResponse)
    @router.get("/", response_class=HTMLResponse)
    async def index(
        request: Request,
        probes: Annotated[Probes, Depends(get_probes)],
    ) -> Response:
        tpl = _templates()
        return tpl.TemplateResponse(
            request,
            "index.html",
            {
                "registered_use_cases": sorted(REGISTRY.keys()),
                "job": None,
                "form_error": None,
                "form_values": {},
                "file_names": [],
                "cpu_mode": _cached_ocr_gpu(probes) is False,
            },
        )

    @router.get("/jobs", response_class=HTMLResponse)
    async def jobs_list(
        request: Request,
        session_factory: Annotated[
            async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
        ],
        status: Annotated[list[str] | None, Query()] = None,
        client_id: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=_JOBS_LIST_MAX_LIMIT)] = _JOBS_LIST_DEFAULT_LIMIT,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> Response:
        # Drop unknown statuses silently — we don't want a stray query
        # param to 400. The filter bar only offers valid values anyway.
        status_filter: list[str] = []
        if status:
            status_filter = [s for s in status if s in _VALID_STATUSES]
        client_filter = (client_id or "").strip() or None

        async with session_factory() as session:
            jobs, total = await jobs_repo.list_recent(
                session,
                limit=limit,
                offset=offset,
                status=status_filter if status_filter else None,
                client_id=client_filter,
            )

        now = datetime.now(UTC)
        rows = []
        for job in jobs:
            files = _file_display_entries(job.request)
            display = files[0] if files else "—"
            created = job.created_at
            created_delta = (
                _humanize_delta(int((now - created).total_seconds()))
                if created is not None
                else "—"
            )
            created_local = (
                created.strftime("%Y-%m-%d %H:%M:%S")
                if created is not None
                else "—"
            )
            rows.append(
                {
                    "job_id": str(job.job_id),
                    "status": job.status,
                    "display_name": display,
                    "use_case": _use_case_label(job.request),
                    "client_id": job.client_id,
                    "created_at": created_local,
                    "created_delta": created_delta,
                    "elapsed": _fmt_elapsed_seconds(_row_elapsed_seconds(job)),
                }
            )

        prev_offset = max(0, offset - limit) if offset > 0 else None
        next_offset = offset + limit if (offset + limit) < total else None

        def _link(new_offset: int) -> str:
            params: list[tuple[str, str]] = []
            for s in status_filter:
                params.append(("status", s))
            if client_filter:
                params.append(("client_id", client_filter))
            params.append(("limit", str(limit)))
            params.append(("offset", str(new_offset)))
            return f"/ui/jobs?{urlencode(params)}"

        tpl = _templates()
        return tpl.TemplateResponse(
            request,
            "jobs_list.html",
            {
                "rows": rows,
                "total": total,
                "shown": len(rows),
                "limit": limit,
                "offset": offset,
                "status_filter": status_filter,
                "client_filter": client_filter or "",
                "valid_statuses": _VALID_STATUSES,
                "prev_link": _link(prev_offset) if prev_offset is not None else None,
                "next_link": _link(next_offset) if next_offset is not None else None,
            },
        )

    @router.get("/jobs/{job_id}", response_class=HTMLResponse)
    async def job_page(
        request: Request,
        job_id: UUID,
        session_factory: Annotated[
            async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
        ],
        probes: Annotated[Probes, Depends(get_probes)],
    ) -> Response:
        async with session_factory() as session:
            job = await jobs_repo.get(session, job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="job not found")
        tpl = _templates()
        return tpl.TemplateResponse(
            request,
            "index.html",
            {
                "registered_use_cases": sorted(REGISTRY.keys()),
                "job": job,
                "form_error": None,
                "form_values": {},
                "file_names": _file_display_entries(job.request),
                "cpu_mode": _cached_ocr_gpu(probes) is False,
            },
        )

    @router.get("/jobs/{job_id}/fragment", response_class=HTMLResponse)
    async def job_fragment(
        request: Request,
        job_id: UUID,
        session_factory: Annotated[
            async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
        ],
        probes: Annotated[Probes, Depends(get_probes)],
    ) -> Response:
        async with session_factory() as session:
            job = await jobs_repo.get(session, job_id)
            if job is None:
                raise HTTPException(status_code=404, detail="job not found")
            ahead, total_active = await jobs_repo.queue_position(
                session, job_id
            )

        response_json: str | None = None
        if job.response is not None:
            response_json = json.dumps(
                job.response.model_dump(mode="json"),
                indent=2,
                sort_keys=True,
                default=str,
            )

        elapsed_text = _format_elapsed(job)
        file_names = _file_display_entries(job.request)

        tpl = _templates()
        return tpl.TemplateResponse(
            request,
            "job_fragment.html",
            {
                "job": job,
                "response_json": response_json,
                "ahead": ahead,
                "total_active": total_active,
                "elapsed_text": elapsed_text,
                "file_names": file_names,
                "cpu_mode": _cached_ocr_gpu(probes) is False,
            },
        )

    @router.post("/jobs")
    async def submit_job(
        request: Request,
        session_factory: Annotated[
            async_sessionmaker[AsyncSession], Depends(get_session_factory_dep)
        ],
        pdf: Annotated[UploadFile, File()],
        use_case_name: Annotated[str, Form()],
        use_case_mode: Annotated[str, Form()] = "registered",
        texts: Annotated[str, Form()] = "",
        ix_client_id: Annotated[str, Form()] = "ui",
        request_id: Annotated[str, Form()] = "",
        system_prompt: Annotated[str, Form()] = "",
        default_model: Annotated[str, Form()] = "",
        fields_json: Annotated[str, Form()] = "",
        use_ocr: Annotated[str, Form()] = "",
        ocr_only: Annotated[str, Form()] = "",
        include_ocr_text: Annotated[str, Form()] = "",
        include_geometries: Annotated[str, Form()] = "",
        gen_ai_model_name: Annotated[str, Form()] = "",
        include_provenance: Annotated[str, Form()] = "",
        max_sources_per_field: Annotated[str, Form()] = "10",
    ) -> Response:
        cfg = get_config()
        form_values = {
            "use_case_mode": use_case_mode,
            "use_case_name": use_case_name,
            "ix_client_id": ix_client_id,
            "request_id": request_id,
            "texts": texts,
            "system_prompt": system_prompt,
            "default_model": default_model,
            "fields_json": fields_json,
            "use_ocr": use_ocr,
            "ocr_only": ocr_only,
            "include_ocr_text": include_ocr_text,
            "include_geometries": include_geometries,
            "gen_ai_model_name": gen_ai_model_name,
            "include_provenance": include_provenance,
            "max_sources_per_field": max_sources_per_field,
        }

        def _rerender(error: str, status: int = 200) -> Response:
            tpl = _templates()
            return tpl.TemplateResponse(
                request,
                "index.html",
                {
                    "registered_use_cases": sorted(REGISTRY.keys()),
                    "job": None,
                    "form_error": error,
                    "form_values": form_values,
                },
                status_code=status,
            )

        # --- Inline use case (optional) ---
        inline: InlineUseCase | None = None
        if use_case_mode == "custom":
            try:
                raw_fields = json.loads(fields_json)
            except json.JSONDecodeError as exc:
                return _rerender(f"Invalid fields JSON: {exc}", status=422)
            if not isinstance(raw_fields, list):
                return _rerender(
                    "Invalid fields JSON: must be a list of field objects",
                    status=422,
                )
            try:
                parsed = [UseCaseFieldDef.model_validate(f) for f in raw_fields]
                inline = InlineUseCase(
                    use_case_name=use_case_name,
                    system_prompt=system_prompt,
                    default_model=default_model or None,
                    fields=parsed,
                )
            except Exception as exc:  # pydantic ValidationError or similar
                return _rerender(
                    f"Invalid inline use-case definition: {exc}",
                    status=422,
                )

        # --- PDF upload ---
        upload_dir = _ui_tmp_dir(cfg)
        target = upload_dir / f"{uuid.uuid4().hex}.pdf"
        # Stream copy with a size cap matching IX_FILE_MAX_BYTES.
        total = 0
        limit = cfg.file_max_bytes
        async with aiofiles.open(target, "wb") as out:
            while True:
                chunk = await pdf.read(64 * 1024)
                if not chunk:
                    break
                total += len(chunk)
                if total > limit:
                    # Drop the partial file; no stored state.
                    from contextlib import suppress

                    with suppress(FileNotFoundError):
                        target.unlink()
                    return _rerender(
                        f"PDF exceeds IX_FILE_MAX_BYTES ({limit} bytes)",
                        status=413,
                    )
                await out.write(chunk)

        # --- Build RequestIX ---
        ctx_texts: list[str] = []
        if texts.strip():
            ctx_texts = [texts.strip()]

        req_id = request_id.strip() or uuid.uuid4().hex
        # Preserve the client-provided filename so the UI can surface the
        # original name to the user (the on-disk name is a UUID). Strip any
        # path prefix a browser included.
        original_name = (pdf.filename or "").rsplit("/", 1)[-1].rsplit(
            "\\", 1
        )[-1] or None
        try:
            request_ix = RequestIX(
                use_case=use_case_name or "adhoc",
                use_case_inline=inline,
                ix_client_id=(ix_client_id.strip() or "ui"),
                request_id=req_id,
                context=Context(
                    files=[
                        FileRef(
                            url=f"file://{target.resolve()}",
                            display_name=original_name,
                        )
                    ],
                    texts=ctx_texts,
                ),
                options=Options(
                    ocr=OCROptions(
                        use_ocr=_flag(use_ocr, default=True),
                        ocr_only=_flag(ocr_only, default=False),
                        include_ocr_text=_flag(include_ocr_text, default=False),
                        include_geometries=_flag(include_geometries, default=False),
                    ),
                    gen_ai=GenAIOptions(
                        gen_ai_model_name=(gen_ai_model_name.strip() or None),
                    ),
                    provenance=ProvenanceOptions(
                        include_provenance=_flag(include_provenance, default=True),
                        max_sources_per_field=int(max_sources_per_field or 10),
                    ),
                ),
            )
        except Exception as exc:
            return _rerender(f"Invalid request: {exc}", status=422)

        async with session_factory() as session:
            job = await jobs_repo.insert_pending(
                session, request_ix, callback_url=None
            )
            await session.commit()

        redirect_to = f"/ui/jobs/{job.job_id}"
        if request.headers.get("HX-Request", "").lower() == "true":
            return Response(status_code=200, headers={"HX-Redirect": redirect_to})
        return RedirectResponse(url=redirect_to, status_code=303)

    return router

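The `_link` helper above relies on `urlencode` accepting a list of pairs, which is what lets the repeated `status` key survive into the pagination URL. A stdlib-only sketch (filter values are illustrative):

```python
from urllib.parse import urlencode

# How a filter-preserving pagination link is composed: repeated "status"
# keys are legal query parameters when urlencode gets a list of tuples.
params = [("status", "done"), ("status", "error"), ("limit", "50"), ("offset", "50")]
link = f"/ui/jobs?{urlencode(params)}"
print(link)  # /ui/jobs?status=done&status=error&limit=50&offset=50
```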

def _flag(value: str, *, default: bool) -> bool:
    """HTML forms omit unchecked checkboxes. Treat absence as ``default``."""

    if value == "":
        return default
    return value.lower() in ("on", "true", "1", "yes")

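The checkbox semantics are easy to get wrong, so here is the behavior spelled out with a standalone mirror of `_flag` (illustrative copy, not an import):

```python
def flag(value: str, *, default: bool) -> bool:
    # Illustrative mirror of _flag: empty string means the checkbox was
    # absent from the form, so fall back to the server-side default.
    if value == "":
        return default
    return value.lower() in ("on", "true", "1", "yes")

print(flag("", default=True))     # True  (absent -> default)
print(flag("on", default=False))  # True  (browsers send "on" when checked)
print(flag("no", default=True))   # False (unrecognized value)
```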

def _format_elapsed(job) -> str | None:  # type: ignore[no-untyped-def]
    """Render a ``MM:SS`` elapsed string for the fragment template.

    * running → time since ``started_at``
    * done/error → ``finished_at - created_at`` (total wall-clock including
      queue time)
    * pending / missing timestamps → ``None`` (template omits the line)
    """

    from datetime import UTC, datetime

    def _fmt(seconds: float) -> str:
        s = max(0, int(seconds))
        return f"{s // 60:02d}:{s % 60:02d}"

    if job.status == "running" and job.started_at is not None:
        now = datetime.now(UTC)
        return _fmt((now - job.started_at).total_seconds())
    if (
        job.status in ("done", "error")
        and job.finished_at is not None
        and job.created_at is not None
    ):
        return _fmt((job.finished_at - job.created_at).total_seconds())
    return None
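Both `_fmt_elapsed_seconds` and the nested `_fmt` use the same MM:SS rule: clamp below at zero, then zero-pad minutes and seconds. A one-liner sketch:

```python
def fmt(seconds: float) -> str:
    # The shared MM:SS rendering rule: clamp negatives, zero-pad both parts.
    s = max(0, int(seconds))
    return f"{s // 60:02d}:{s % 60:02d}"

print(fmt(125.7))  # 02:05
print(fmt(-3))     # 00:00  (clock skew cannot produce a negative display)
```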
src/ix/ui/static/.gitkeep (new file, 0 lines)
src/ix/ui/templates/index.html (new file, 242 lines)
@@ -0,0 +1,242 @@
<!doctype html>
<html lang="en" data-theme="light">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>
      InfoXtractor{% if job %} — job {{ job.job_id }}{% endif %}
    </title>
    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css"
    />
    <link
      rel="stylesheet"
      href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/atom-one-light.min.css"
    />
    <script src="https://unpkg.com/htmx.org@1.9.12"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
    <style>
      main { padding-top: 1.5rem; padding-bottom: 4rem; }
      pre code.hljs { padding: 1rem; border-radius: 0.4rem; }
      .form-error { color: var(--pico-del-color, #c44); font-weight: 600; }
      details[open] > summary { margin-bottom: 0.5rem; }
      .field-hint { font-size: 0.85rem; color: var(--pico-muted-color); }
      nav.ix-header {
        display: flex; gap: 1rem; align-items: baseline;
        padding: 0.6rem 0; border-bottom: 1px solid var(--pico-muted-border-color, #ddd);
        margin-bottom: 1rem; flex-wrap: wrap;
      }
      nav.ix-header .brand { font-weight: 700; margin-right: auto; }
      nav.ix-header code { font-size: 0.9em; }
      .status-panel, .result-panel { margin-top: 0.75rem; }
      .status-panel header, .result-panel header { font-size: 0.95rem; }
      .job-files code { font-size: 0.9em; }
      .cpu-notice { margin-top: 0.6rem; font-size: 0.9rem; color: var(--pico-muted-color); }
      .live-dot {
        display: inline-block; margin-left: 0.3rem;
        animation: ix-blink 1.2s ease-in-out infinite;
        color: var(--pico-primary, #4f8cc9);
      }
      @keyframes ix-blink {
        0%, 100% { opacity: 0.2; }
        50% { opacity: 1; }
      }
      .copy-btn {
        margin-left: 0.3rem; padding: 0.1rem 0.5rem;
        font-size: 0.8rem; line-height: 1.2;
      }
    </style>
  </head>
  <body>
    <main class="container">
      <nav class="ix-header" aria-label="InfoXtractor navigation">
        <span class="brand">InfoXtractor</span>
        <a href="/ui">Upload a new extraction</a>
        <a href="/ui/jobs">Recent jobs</a>
        {% if job %}
        <span>
          Job:
          <code id="current-job-id">{{ job.job_id }}</code>
          <button
            type="button"
            class="secondary outline copy-btn"
            onclick="navigator.clipboard && navigator.clipboard.writeText('{{ job.job_id }}')"
            aria-label="Copy job id to clipboard"
          >Copy</button>
        </span>
        {% endif %}
      </nav>

      <hgroup>
        <h1>infoxtractor</h1>
        <p>Drop a PDF, pick or define a use case, run the pipeline.</p>
      </hgroup>

      {% if form_error %}
      <article class="form-error">
        <p><strong>Form error:</strong> {{ form_error }}</p>
      </article>
      {% endif %}

      {% if not job %}
      <article>
        <form
          action="/ui/jobs"
          method="post"
          enctype="multipart/form-data"
          hx-post="/ui/jobs"
          hx-encoding="multipart/form-data"
        >
          <label>
            PDF file
            <input type="file" name="pdf" accept="application/pdf" required />
          </label>

          <label>
            Extra texts (optional, e.g. Paperless OCR output)
            <textarea
              name="texts"
              rows="3"
              placeholder="Plain text passed as context.texts[0]"
            >{{ form_values.get("texts", "") }}</textarea>
            <small class="field-hint">Whatever you type is submitted as a single entry in <code>context.texts</code>.</small>
          </label>

          <fieldset>
            <legend>Use case</legend>
            <label>
              <input
                type="radio"
                name="use_case_mode"
                value="registered"
                {% if form_values.get("use_case_mode", "registered") == "registered" %}checked{% endif %}
                onchange="document.getElementById('custom-fields').hidden = true"
              />
              Registered
            </label>
            <label>
              <input
                type="radio"
                name="use_case_mode"
                value="custom"
                {% if form_values.get("use_case_mode") == "custom" %}checked{% endif %}
                onchange="document.getElementById('custom-fields').hidden = false"
              />
              Custom (inline)
            </label>

            <label>
              Use case name
              <input
                type="text"
                name="use_case_name"
                list="registered-use-cases"
                value="{{ form_values.get('use_case_name', 'bank_statement_header') }}"
                required
              />
              <datalist id="registered-use-cases">
                {% for name in registered_use_cases %}
                <option value="{{ name }}"></option>
                {% endfor %}
              </datalist>
            </label>

            <div id="custom-fields" {% if form_values.get("use_case_mode") != "custom" %}hidden{% endif %}>
              <label>
                System prompt
                <textarea name="system_prompt" rows="3">{{ form_values.get("system_prompt", "") }}</textarea>
              </label>
              <label>
                Default model (optional)
                <input
                  type="text"
                  name="default_model"
                  value="{{ form_values.get('default_model', '') }}"
                  placeholder="qwen3:14b"
                />
              </label>
              <label>
                Fields (JSON list of {name, type, required?, choices?, description?})
                <textarea name="fields_json" rows="6" placeholder='[{"name": "vendor", "type": "str", "required": true}]'>{{ form_values.get("fields_json", "") }}</textarea>
                <small class="field-hint">Types: str, int, float, decimal, date, datetime, bool. <code>choices</code> works on <code>str</code> only.</small>
              </label>
            </div>
          </fieldset>

          <details>
            <summary>Advanced options</summary>
            <label>
              Client id
              <input type="text" name="ix_client_id" value="{{ form_values.get('ix_client_id', 'ui') }}" />
            </label>
            <label>
              Request id (blank → random)
              <input type="text" name="request_id" value="{{ form_values.get('request_id', '') }}" />
            </label>

            <fieldset>
              <legend>OCR</legend>
              <label><input type="checkbox" name="use_ocr" {% if form_values.get("use_ocr", "on") %}checked{% endif %} /> use_ocr</label>
              <label><input type="checkbox" name="ocr_only" {% if form_values.get("ocr_only") %}checked{% endif %} /> ocr_only</label>
              <label><input type="checkbox" name="include_ocr_text" {% if form_values.get("include_ocr_text") %}checked{% endif %} /> include_ocr_text</label>
              <label><input type="checkbox" name="include_geometries" {% if form_values.get("include_geometries") %}checked{% endif %} /> include_geometries</label>
            </fieldset>

            <label>
              GenAI model override (optional)
              <input type="text" name="gen_ai_model_name" value="{{ form_values.get('gen_ai_model_name', '') }}" />
            </label>

            <fieldset>
              <legend>Provenance</legend>
              <label><input type="checkbox" name="include_provenance" {% if form_values.get("include_provenance", "on") %}checked{% endif %} /> include_provenance</label>
              <label>
                max_sources_per_field
                <input type="number" name="max_sources_per_field" min="1" max="100" value="{{ form_values.get('max_sources_per_field', '10') }}" />
              </label>
            </fieldset>
          </details>

          <button type="submit">Submit</button>
        </form>
      </article>
      {% endif %}

      {% if job %}
      <article id="job-panel">
        <header>
          <strong>Job</strong> <code>{{ job.job_id }}</code>
          <br /><small>ix_id: <code>{{ job.ix_id }}</code></small>
          {% if file_names %}
          <br /><small>
            File{% if file_names|length > 1 %}s{% endif %}:
            {% for name in file_names %}
            <code>{{ name }}</code>{% if not loop.last %}, {% endif %}
            {% endfor %}
          </small>
          {% endif %}
        </header>
        <div
          id="job-status"
          hx-get="/ui/jobs/{{ job.job_id }}/fragment"
          hx-trigger="load"
          hx-swap="innerHTML"
        >
          Loading…
        </div>
      </article>
      {% endif %}
    </main>

    <script>
      document.body.addEventListener("htmx:afterSettle", () => {
        if (window.hljs) {
          document.querySelectorAll("pre code").forEach((el) => {
            try { hljs.highlightElement(el); } catch (_) { /* noop */ }
          });
        }
      });
    </script>
  </body>
</html>
src/ix/ui/templates/job_fragment.html (new file, 77 lines)
@@ -0,0 +1,77 @@
{#- HTMX fragment rendered into #job-status on the results panel.
    Pending/running → keep polling every 2s; terminal → render JSON. -#}
{% set terminal = job.status in ("done", "error") %}
<div
  id="job-fragment"
  {% if not terminal %}
  hx-get="/ui/jobs/{{ job.job_id }}/fragment"
  hx-trigger="every 2s"
  hx-swap="outerHTML"
  {% endif %}
>
  <article class="status-panel">
    <header>
      <strong>Job status</strong>
    </header>

    <p>
      Status:
      <strong>{{ job.status }}</strong>
      {% if not terminal %}
      <span class="live-dot" aria-hidden="true">●</span>
      {% endif %}
    </p>

    {% if file_names %}
    <p class="job-files">
      File{% if file_names|length > 1 %}s{% endif %}:
      {% for name in file_names %}
      <code>{{ name }}</code>{% if not loop.last %}, {% endif %}
      {% endfor %}
    </p>
    {% endif %}

    {% if job.status == "pending" %}
    <p>
      {% if ahead == 0 %}
      About to start — the worker just freed up.
      {% else %}
      Queue position: {{ ahead }} ahead — {{ total_active }} job{% if total_active != 1 %}s{% endif %} total in flight (single worker).
      {% endif %}
    </p>
    <progress></progress>
    {% elif job.status == "running" %}
    {% if elapsed_text %}
    <p>Running for {{ elapsed_text }}.</p>
    {% endif %}
    <progress></progress>
    {% elif terminal %}
    {% if elapsed_text %}
    <p>Finished in {{ elapsed_text }}.</p>
    {% endif %}
    {% endif %}

    {% if cpu_mode and not terminal %}
    <details class="cpu-notice">
      <summary>Surya is running on CPU (~1–2 min/page)</summary>
      <p>
        A host NVIDIA driver upgrade would unlock GPU extraction; tracked in
        <code>docs/deployment.md</code>.
      </p>
    </details>
    {% endif %}
  </article>

  <article class="result-panel">
    <header>
      <strong>Result</strong>
    </header>
    {% if terminal and response_json %}
    <pre><code class="language-json">{{ response_json }}</code></pre>
    {% elif terminal %}
    <p><em>No response body.</em></p>
    {% else %}
    <p><em>Waiting for the pipeline to finish…</em></p>
    {% endif %}
  </article>
</div>
src/ix/ui/templates/jobs_list.html (new file, 164 lines)
@@ -0,0 +1,164 @@
<!doctype html>
<html lang="en" data-theme="light">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>InfoXtractor — Recent jobs</title>
    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css"
    />
    <style>
      main { padding-top: 1.5rem; padding-bottom: 4rem; }
      nav.ix-header {
        display: flex; gap: 1rem; align-items: baseline;
        padding: 0.6rem 0; border-bottom: 1px solid var(--pico-muted-border-color, #ddd);
        margin-bottom: 1rem; flex-wrap: wrap;
      }
      nav.ix-header .brand { font-weight: 700; margin-right: auto; }
      .breadcrumb {
        font-size: 0.9rem; color: var(--pico-muted-color);
        margin-bottom: 0.75rem;
      }
      .breadcrumb a { text-decoration: none; }
      .filter-bar {
        display: flex; flex-wrap: wrap; gap: 1rem; align-items: flex-end;
        margin-bottom: 1rem;
      }
      .filter-bar fieldset { margin: 0; padding: 0; border: none; }
      .filter-bar label.inline { display: inline-flex; gap: 0.3rem; align-items: center; margin-right: 0.8rem; font-weight: normal; }
      .counter { color: var(--pico-muted-color); margin-bottom: 0.5rem; }
      table.jobs-table { width: 100%; font-size: 0.92rem; }
      table.jobs-table th { white-space: nowrap; }
      table.jobs-table td { vertical-align: middle; }
      td.col-created small { color: var(--pico-muted-color); display: block; }
      .status-badge {
        display: inline-block; padding: 0.1rem 0.55rem;
        border-radius: 0.8rem; font-size: 0.78rem; font-weight: 600;
        text-transform: uppercase; letter-spacing: 0.04em;
      }
      .status-done { background: #d1f4dc; color: #1a6d35; }
      .status-error { background: #fadadd; color: #8a1d2b; }
      .status-pending, .status-running { background: #fff1c2; color: #805600; }
      .pagination {
        display: flex; gap: 0.75rem; margin-top: 1rem;
        align-items: center; flex-wrap: wrap;
      }
      .empty-note { color: var(--pico-muted-color); font-style: italic; }
      td.col-filename code { font-size: 0.9em; word-break: break-all; }
    </style>
  </head>
  <body>
    <main class="container">
      <nav class="ix-header" aria-label="InfoXtractor navigation">
        <span class="brand">InfoXtractor</span>
        <a href="/ui">Upload a new extraction</a>
        <a href="/ui/jobs">Recent jobs</a>
      </nav>

      <p class="breadcrumb">
        <a href="/ui">Home</a> › Jobs
      </p>

      <hgroup>
        <h1>Recent jobs</h1>
        <p>All submitted extractions, newest first.</p>
      </hgroup>

      <form class="filter-bar" method="get" action="/ui/jobs">
        <fieldset>
          <legend><small>Status</small></legend>
          {% for s in valid_statuses %}
          <label class="inline">
            <input
              type="checkbox"
              name="status"
              value="{{ s }}"
              {% if s in status_filter %}checked{% endif %}
            />
            {{ s }}
          </label>
          {% endfor %}
        </fieldset>
        <label>
          Client id
          <input
            type="text"
            name="client_id"
            value="{{ client_filter }}"
            placeholder="e.g. ui, mammon"
          />
        </label>
        <label>
          Page size
          <input
            type="number"
            name="limit"
            min="1"
            max="200"
            value="{{ limit }}"
          />
        </label>
        <button type="submit">Apply</button>
      </form>

      <p class="counter">
        Showing {{ shown }} of {{ total }} job{% if total != 1 %}s{% endif %}.
      </p>

      {% if rows %}
      <figure>
        <table class="jobs-table" role="grid">
          <thead>
            <tr>
              <th>Status</th>
              <th>Filename</th>
              <th>Use case</th>
              <th>Client</th>
              <th>Submitted</th>
              <th>Elapsed</th>
              <th></th>
            </tr>
          </thead>
          <tbody>
            {% for row in rows %}
            <tr>
              <td>
                <span class="status-badge status-{{ row.status }}">{{ row.status }}</span>
              </td>
              <td class="col-filename"><code>{{ row.display_name }}</code></td>
              <td>{{ row.use_case }}</td>
              <td>{{ row.client_id }}</td>
              <td class="col-created">
                {{ row.created_at }}
                <small>{{ row.created_delta }}</small>
              </td>
              <td>{{ row.elapsed }}</td>
              <td>
                <a href="/ui/jobs/{{ row.job_id }}">open ›</a>
              </td>
            </tr>
            {% endfor %}
          </tbody>
        </table>
      </figure>
      {% else %}
      <p class="empty-note">No jobs match the current filters.</p>
      {% endif %}

      <div class="pagination">
        {% if prev_link %}
        <a href="{{ prev_link }}" role="button" class="secondary outline">« Prev</a>
        {% else %}
        <span aria-disabled="true" class="secondary outline" role="button" style="opacity: 0.4;">« Prev</span>
        {% endif %}
        <span class="counter">Offset {{ offset }}</span>
        {% if next_link %}
        <a href="{{ next_link }}" role="button" class="secondary outline">Next »</a>
        {% else %}
        <span aria-disabled="true" class="secondary outline" role="button" style="opacity: 0.4;">Next »</span>
        {% endif %}
      </div>
    </main>
  </body>
</html>
@@ -26,7 +26,7 @@ class Request(BaseModel):
     model_config = ConfigDict(extra="forbid")

     use_case_name: str = "Bank Statement Header"
-    default_model: str = "gpt-oss:20b"
+    default_model: str = "qwen3:14b"
     system_prompt: str = (
         "You extract header metadata from a single bank or credit-card statement. "
         "Return only facts that appear in the document; leave a field null if uncertain. "
src/ix/use_cases/inline.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""Dynamic Pydantic class builder for caller-supplied use cases.

Input: an :class:`ix.contracts.request.InlineUseCase` carried on the
:class:`~ix.contracts.request.RequestIX`.

Output: a fresh ``(RequestClass, ResponseClass)`` pair with the same shape
as a registered use case. The :class:`~ix.pipeline.setup_step.SetupStep`
calls this when ``request_ix.use_case_inline`` is set, bypassing the
registry lookup entirely.

The builder returns brand-new classes on every call — safe to call per
request, so two concurrent jobs can't step on each other's schemas even if
they happen to share a ``use_case_name``. Validation errors map to
``IX_001_001`` (same code the registry-miss path uses); the error is
recoverable from the caller's perspective (fix the JSON and retry), not an
infra problem.
"""

from __future__ import annotations

import keyword
import re
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Literal, cast

from pydantic import BaseModel, ConfigDict, Field, create_model

from ix.contracts.request import InlineUseCase, UseCaseFieldDef
from ix.errors import IXErrorCode, IXException

# Map the ``UseCaseFieldDef.type`` literal to concrete Python types.
_TYPE_MAP: dict[str, type] = {
    "str": str,
    "int": int,
    "float": float,
    "decimal": Decimal,
    "date": date,
    "datetime": datetime,
    "bool": bool,
}


def _fail(detail: str) -> IXException:
    return IXException(IXErrorCode.IX_001_001, detail=detail)


def _valid_field_name(name: str) -> bool:
    """Require a valid Python identifier that isn't a reserved keyword."""
    return name.isidentifier() and not keyword.iskeyword(name)


def _resolve_field_type(field: UseCaseFieldDef) -> Any:
    """Return the annotation for a single field, with ``choices`` honoured."""
    base = _TYPE_MAP[field.type]
    if field.choices:
        if field.type != "str":
            raise _fail(
                f"field {field.name!r}: 'choices' is only allowed for "
                f"type='str' (got {field.type!r})"
            )
        return Literal[tuple(field.choices)]  # type: ignore[valid-type]
    return base


def _sanitise_class_name(raw: str) -> str:
    """``re.sub(r"\\W", "_", name)`` + ``Inline_`` prefix.

    Keeps the generated class name debuggable (shows up in repr / tracebacks)
    while ensuring it's always a valid Python identifier.
    """
    return "Inline_" + re.sub(r"\W", "_", raw)


def build_use_case_classes(
    inline: InlineUseCase,
) -> tuple[type[BaseModel], type[BaseModel]]:
    """Build a fresh ``(RequestClass, ResponseClass)`` from ``inline``.

    * Every call returns new classes. The caller may cache if desired; the
      pipeline intentionally does not.
    * Raises :class:`~ix.errors.IXException` with code
      :attr:`~ix.errors.IXErrorCode.IX_001_001` on any structural problem
      (empty fields, bad name, dup name, bad ``choices``).
    """
    if not inline.fields:
        raise _fail("inline use case must define at least one field")

    seen: set[str] = set()
    for fd in inline.fields:
        if not _valid_field_name(fd.name):
            raise _fail(f"field name {fd.name!r} is not a valid Python identifier")
        if fd.name in seen:
            raise _fail(f"duplicate field name {fd.name!r}")
        seen.add(fd.name)

    response_fields: dict[str, Any] = {}
    for fd in inline.fields:
        annotation = _resolve_field_type(fd)
        field_info = (
            Field(..., description=fd.description)
            if fd.required
            else Field(default=None, description=fd.description)
        )
        if not fd.required:
            annotation = annotation | None
        response_fields[fd.name] = (annotation, field_info)

    response_cls = create_model(  # type: ignore[call-overload]
        _sanitise_class_name(inline.use_case_name),
        __config__=ConfigDict(extra="forbid"),
        **response_fields,
    )

    request_cls = create_model(  # type: ignore[call-overload]
        "Inline_Request_" + re.sub(r"\W", "_", inline.use_case_name),
        __config__=ConfigDict(extra="forbid"),
        use_case_name=(str, inline.use_case_name),
        system_prompt=(str, inline.system_prompt),
        default_model=(str | None, inline.default_model),
    )

    return cast(type[BaseModel], request_cls), cast(type[BaseModel], response_cls)


__all__ = ["build_use_case_classes"]
7 src/ix/worker/__init__.py Normal file

@@ -0,0 +1,7 @@
"""Async worker — pulls pending rows and runs the pipeline against them.

The worker is one asyncio task spawned by the FastAPI lifespan (see
``ix.app``). Single-concurrency per MVP spec (Ollama + Surya both want the
GPU serially). Production wiring lives in Chunk 4; until then the pipeline
factory is parameter-injected so tests pass a fakes-only Pipeline.
"""
44 src/ix/worker/callback.py Normal file

@@ -0,0 +1,44 @@
"""One-shot webhook callback delivery.

No retries — the caller always has ``GET /jobs/{id}`` as the authoritative
fallback. We record the delivery outcome (``delivered`` / ``failed``) on the
row but never change ``status`` based on it; terminal states are stable.

Spec §5 callback semantics: one POST, 2xx → delivered, anything else or
exception → failed.
"""

from __future__ import annotations

from typing import Literal

import httpx

from ix.contracts.job import Job


async def deliver(
    callback_url: str,
    job: Job,
    timeout_s: int,
) -> Literal["delivered", "failed"]:
    """POST the full :class:`Job` body to ``callback_url``; return the outcome.

    ``timeout_s`` caps both connect and read — we don't configure them
    separately for callbacks because the endpoint is caller-supplied and we
    don't have a reason to treat slow-to-connect differently from
    slow-to-respond. Any exception (connection error, timeout, non-2xx)
    collapses to ``"failed"``.
    """
    try:
        async with httpx.AsyncClient(timeout=timeout_s) as client:
            response = await client.post(
                callback_url,
                json=job.model_dump(mode="json"),
            )
        if 200 <= response.status_code < 300:
            return "delivered"
        return "failed"
    except Exception:
        return "failed"
179 src/ix/worker/loop.py Normal file

@@ -0,0 +1,179 @@
"""Worker loop — claim pending rows, run pipeline, write terminal state.

One ``Worker`` instance per process. The loop body is:

1. Claim the next pending row (``FOR UPDATE SKIP LOCKED``). If none, wait
   for the notify event or the poll interval, whichever fires first.
2. Build a fresh Pipeline via the injected factory and run it.
3. Write the response via ``mark_done`` (spec's ``done iff error is None``
   invariant). If the pipeline itself raised (it shouldn't — steps catch
   IXException internally — but belt-and-braces), we stuff an
   ``IX_002_000`` into ``response.error`` and mark_error.
4. If the job has a ``callback_url``, POST once, record the outcome.

Startup preamble:

* Run ``sweep_orphans(now, 2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS)`` once
  before the loop starts. Recovers rows left in ``running`` by a crashed
  previous process.

The "wait for work" hook is a callable so Task 3.6's PgQueueListener can
plug in later without the worker needing to know anything about LISTEN.
"""

from __future__ import annotations

import asyncio
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING

from ix.contracts.response import ResponseIX
from ix.errors import IXErrorCode, IXException
from ix.pipeline.pipeline import Pipeline
from ix.store import jobs_repo
from ix.worker import callback as cb

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

PipelineFactory = Callable[[], Pipeline]
WaitForWork = Callable[[float], "asyncio.Future[None] | asyncio.Task[None]"]


class Worker:
    """Single-concurrency worker loop.

    Parameters
    ----------
    session_factory:
        async_sessionmaker bound to an engine on the current event loop.
    pipeline_factory:
        Zero-arg callable returning a fresh :class:`Pipeline`. In production
        this builds the real pipeline with Ollama + Surya; in tests it
        returns a Pipeline of fakes.
    poll_interval_seconds:
        Fallback poll cadence when no notify wakes us (spec: 10 s default).
    max_running_seconds:
        Threshold passed to :func:`sweep_orphans` at startup.
        Production wiring passes ``2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS``.
    callback_timeout_seconds:
        Timeout for the webhook POST per spec §5.
    wait_for_work:
        Optional coroutine-factory. When set, the worker awaits it instead
        of ``asyncio.sleep``. Task 3.6 passes the PgQueueListener's
        notify-or-poll helper.
    """

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession],
        pipeline_factory: PipelineFactory,
        poll_interval_seconds: float = 10.0,
        max_running_seconds: int = 5400,
        callback_timeout_seconds: int = 10,
        wait_for_work: Callable[[float], asyncio.Future[None]] | None = None,
    ) -> None:
        self._session_factory = session_factory
        self._pipeline_factory = pipeline_factory
        self._poll_interval = poll_interval_seconds
        self._max_running_seconds = max_running_seconds
        self._callback_timeout = callback_timeout_seconds
        self._wait_for_work = wait_for_work

    async def run(self, stop: asyncio.Event) -> None:
        """Drive the claim-run-write-callback loop until ``stop`` is set."""
        await self._startup_sweep()

        while not stop.is_set():
            async with self._session_factory() as session:
                job = await jobs_repo.claim_next_pending(session)
                await session.commit()

            if job is None:
                await self._sleep_or_wake(stop)
                continue

            await self._run_one(job)

    async def _startup_sweep(self) -> None:
        """Rescue ``running`` rows left behind by a previous crash."""
        async with self._session_factory() as session:
            await jobs_repo.sweep_orphans(
                session,
                datetime.now(UTC),
                self._max_running_seconds,
            )
            await session.commit()

    async def _sleep_or_wake(self, stop: asyncio.Event) -> None:
        """Either run the custom wait hook or sleep the poll interval.

        We cap the wait at either the poll interval or the stop signal,
        whichever fires first — without this, a worker with no notify hook
        would happily sleep for 10 s while the outer app is trying to shut
        down.
        """
        stop_task = asyncio.create_task(stop.wait())
        try:
            if self._wait_for_work is not None:
                wake_task = asyncio.ensure_future(
                    self._wait_for_work(self._poll_interval)
                )
            else:
                wake_task = asyncio.create_task(
                    asyncio.sleep(self._poll_interval)
                )
            try:
                await asyncio.wait(
                    {stop_task, wake_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                if not wake_task.done():
                    wake_task.cancel()
        finally:
            if not stop_task.done():
                stop_task.cancel()

    async def _run_one(self, job) -> None:  # type: ignore[no-untyped-def]
        """Run the pipeline for one job; persist the outcome + callback."""
        pipeline = self._pipeline_factory()
        try:
            response = await pipeline.start(job.request)
        except Exception as exc:
            # The pipeline normally catches IXException itself. Non-IX
            # failures land here. We wrap the message in IX_002_000 so the
            # caller sees a stable code.
            ix_exc = IXException(IXErrorCode.IX_002_000, detail=str(exc))
            response = ResponseIX(error=str(ix_exc))
            async with self._session_factory() as session:
                await jobs_repo.mark_error(session, job.job_id, response)
                await session.commit()
        else:
            async with self._session_factory() as session:
                await jobs_repo.mark_done(session, job.job_id, response)
                await session.commit()

        if job.callback_url:
            await self._deliver_callback(job.job_id, job.callback_url)

    async def _deliver_callback(self, job_id, callback_url: str) -> None:  # type: ignore[no-untyped-def]
        # Re-fetch the job so the callback payload reflects the final terminal
        # state + response. Cheaper than threading the freshly-marked state
        # back out of ``mark_done``, and keeps the callback body canonical.
        async with self._session_factory() as session:
            final = await jobs_repo.get(session, job_id)
        if final is None:
            return
        status = await cb.deliver(callback_url, final, self._callback_timeout)
        async with self._session_factory() as session:
            await jobs_repo.update_callback_status(session, job_id, status)
            await session.commit()
98 tests/fixtures/synthetic_giro.pdf vendored Normal file

@@ -0,0 +1,98 @@
%PDF-1.7
%µ¶
% Written by MuPDF 1.27.2

1 0 obj
<</Type/Catalog/Pages 2 0 R/Info<</Producer(MuPDF 1.27.2)>>>>
endobj

2 0 obj
<</Type/Pages/Count 1/Kids[4 0 R]>>
endobj

3 0 obj
<</Font<</helv 5 0 R>>>>
endobj

4 0 obj
<</Type/Page/MediaBox[0 0 595 842]/Rotate 0/Resources 3 0 R/Parent 2 0 R/Contents[6 0 R 7 0 R 8 0 R 9 0 R 10 0 R 11 0 R]>>
endobj

5 0 obj
<</Type/Font/Subtype/Type1/BaseFont/Helvetica/Encoding/WinAnsiEncoding>>
endobj

6 0 obj
<</Length 54>>
stream
q
BT
1 0 0 1 72 770 Tm
/helv 12 Tf [<444b42>]TJ
ET
Q
endstream
endobj

7 0 obj
<</Length 95/Filter/FlateDecode>>
stream
(binary FlateDecode stream data, not representable as text)
endstream
endobj

8 0 obj
<</Length 105/Filter/FlateDecode>>
stream
(binary FlateDecode stream data, not representable as text)
endstream
endobj

9 0 obj
<</Length 100/Filter/FlateDecode>>
stream
(binary FlateDecode stream data, not representable as text)
endstream
endobj

10 0 obj
<</Length 99/Filter/FlateDecode>>
stream
(binary FlateDecode stream data, not representable as text)
endstream
endobj

11 0 obj
<</Length 93/Filter/FlateDecode>>
stream
(binary FlateDecode stream data, not representable as text)
endstream
endobj

xref
0 12
0000000000 65535 f
0000000042 00000 n
0000000120 00000 n
0000000172 00000 n
0000000213 00000 n
0000000352 00000 n
0000000441 00000 n
0000000544 00000 n
0000000707 00000 n
0000000881 00000 n
0000001050 00000 n
0000001218 00000 n

trailer
<</Size 12/Root 1 0 R/ID[<C3B4C38E004FC2B6C3A0C2BF4C00C282><890F3E53B827FF9B00CB90D2895721FC>]>>
startxref
1380
%%EOF
124 tests/integration/conftest.py Normal file

@@ -0,0 +1,124 @@
"""Integration-test fixtures — real Postgres required.

Policy: tests that import these fixtures skip cleanly when no DB is
configured. We check ``IX_TEST_DATABASE_URL`` first (local developer
override, usually a disposable docker container), then ``IX_POSTGRES_URL``
(what Forgejo Actions already sets). If neither is present the fixture
short-circuits with ``pytest.skip`` so a developer running
``pytest tests/unit`` in an unconfigured shell doesn't see the integration
suite hang or raise cryptic ``OperationalError``.

Schema lifecycle:

* session scope: ``alembic upgrade head`` once, ``alembic downgrade base``
  at session end. We tried ``Base.metadata.create_all`` at first — faster,
  but it meant migrations stayed untested by the integration suite and a
  developer who broke ``001_initial_ix_jobs.py`` wouldn't find out until
  deploy. Current shape keeps migrations in the hot path.
* per-test: ``TRUNCATE ix_jobs`` (via the ``_reset_schema`` autouse fixture)
  — faster than recreating the schema and preserves indexes/constraints so
  tests that want to assert on a unique-violation path actually get one.
"""

from __future__ import annotations

import os
import subprocess
import sys
from collections.abc import AsyncIterator, Iterator
from pathlib import Path

import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

REPO_ROOT = Path(__file__).resolve().parents[2]


def _resolve_postgres_url() -> str | None:
    """Pick the database URL per policy: test override → CI URL → none."""
    return os.environ.get("IX_TEST_DATABASE_URL") or os.environ.get("IX_POSTGRES_URL")


@pytest.fixture(scope="session")
def postgres_url() -> str:
    url = _resolve_postgres_url()
    if not url:
        pytest.skip(
            "no postgres configured — set IX_TEST_DATABASE_URL or IX_POSTGRES_URL"
        )
    return url


def _run_alembic(direction: str, postgres_url: str) -> None:
    """Invoke Alembic in a subprocess so its ``asyncio.run`` inside ``env.py``
    doesn't collide with the pytest-asyncio event loop.

    We pass the URL via ``IX_POSTGRES_URL`` — not ``-x url=...`` — because
    percent-encoded characters in developer passwords trip up alembic's
    configparser-backed ini loader. The env var lane skips configparser.
    """
    env = os.environ.copy()
    env["IX_POSTGRES_URL"] = postgres_url
    subprocess.run(
        [sys.executable, "-m", "alembic", direction, "head" if direction == "upgrade" else "base"],
        cwd=REPO_ROOT,
        env=env,
        check=True,
    )


@pytest.fixture(scope="session", autouse=True)
def _prepare_schema(postgres_url: str) -> Iterator[None]:
    """Run migrations once per session, torn down at the end.

    pytest-asyncio creates one event loop per test (function-scoped by
    default) and asyncpg connections can't survive a loop switch. That
    forces a function-scoped engine below — but migrations are expensive,
    so we keep those session-scoped via a subprocess call (no loop
    involvement at all).
    """
    _run_alembic("downgrade", postgres_url)
    _run_alembic("upgrade", postgres_url)
    yield
    _run_alembic("downgrade", postgres_url)


@pytest_asyncio.fixture
async def engine(postgres_url: str) -> AsyncIterator[AsyncEngine]:
    """Per-test async engine.

    Built fresh each test so its asyncpg connections live on the same loop
    as the test itself. Dispose on teardown — otherwise asyncpg leaks tasks
    into the next test's loop and we get ``got Future attached to a
    different loop`` errors on the second test in a file.
    """
    eng = create_async_engine(postgres_url, pool_pre_ping=True)
    try:
        yield eng
    finally:
        await eng.dispose()


@pytest_asyncio.fixture
async def session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Per-test session factory. ``expire_on_commit=False`` per prod parity."""
    return async_sessionmaker(engine, expire_on_commit=False)


@pytest_asyncio.fixture(autouse=True)
async def _reset_schema(engine: AsyncEngine) -> None:
    """Truncate ix_jobs between tests so each test starts from empty state."""
    async with engine.begin() as conn:
        await conn.exec_driver_sql("TRUNCATE ix_jobs")
679 tests/integration/test_jobs_repo.py Normal file

@@ -0,0 +1,679 @@
"""Integration tests for :mod:`ix.store.jobs_repo` — run against a real DB.

Every test exercises one repo method end-to-end. A few go further and
concurrently spin up two sessions to demonstrate the claim query behaves
correctly under ``SKIP LOCKED`` (two claimers should never see the same row).

Skipped cleanly when no Postgres is configured — see integration/conftest.py.
"""

from __future__ import annotations

import asyncio
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

from ix.contracts.request import Context, RequestIX
from ix.contracts.response import ResponseIX
from ix.store import jobs_repo

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


def _make_request(client: str = "mammon", request_id: str = "r-1") -> RequestIX:
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id=client,
        request_id=request_id,
        context=Context(texts=["hello"]),
    )


async def test_insert_pending_creates_row_and_assigns_ix_id(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        job = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    assert job.status == "pending"
    assert isinstance(job.job_id, UUID)
    # ix_id is a 16-hex string per spec §3 — transport-assigned.
    assert isinstance(job.ix_id, str)
    assert len(job.ix_id) == 16
    assert all(c in "0123456789abcdef" for c in job.ix_id)
    assert job.attempts == 0
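The assertions above pin down the *shape* of `ix_id` (16 lowercase hex characters) without pinning how it is generated. A plausible generator satisfying exactly those assertions — an assumption about the repo's implementation, shown only to make the shape concrete — is stdlib `secrets`:

```python
import secrets


def make_ix_id() -> str:
    """Return a 16-character lowercase-hex transport id (8 random bytes).

    token_hex(n) emits 2*n hex digits, so 8 bytes gives the 16-char
    string the spec's assertions expect.
    """
    return secrets.token_hex(8)


ix_id = make_ix_id()
print(ix_id, len(ix_id))  # e.g. '9f2c4ab0e113d7c2' 16
```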
async def test_insert_pending_is_idempotent_on_correlation_key(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """(client_id, request_id) collides → existing row comes back unchanged."""
    async with session_factory() as session:
        first = await jobs_repo.insert_pending(
            session, _make_request("mammon", "same-id"), callback_url="http://x/cb"
        )
        await session.commit()

    async with session_factory() as session:
        second = await jobs_repo.insert_pending(
            session, _make_request("mammon", "same-id"), callback_url="http://y/cb"
        )
        await session.commit()

    assert second.job_id == first.job_id
    assert second.ix_id == first.ix_id
    # The callback_url of the FIRST insert wins — we don't overwrite.
    assert second.callback_url == "http://x/cb"


async def test_get_returns_full_job(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    async with session_factory() as session:
        fetched = await jobs_repo.get(session, inserted.job_id)

    assert fetched is not None
    assert fetched.job_id == inserted.job_id
    assert fetched.request.use_case == "bank_statement_header"
    assert fetched.status == "pending"


async def test_get_unknown_id_returns_none(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        result = await jobs_repo.get(session, uuid4())
    assert result is None


async def test_get_by_correlation(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request("mammon", "req-42"), callback_url=None
        )
        await session.commit()

    async with session_factory() as session:
        found = await jobs_repo.get_by_correlation(session, "mammon", "req-42")
    assert found is not None
    assert found.job_id == inserted.job_id

    async with session_factory() as session:
        missing = await jobs_repo.get_by_correlation(session, "mammon", "nope")
    assert missing is None


async def test_claim_next_pending_advances_status(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    async with session_factory() as session:
        claimed = await jobs_repo.claim_next_pending(session)
        await session.commit()

    assert claimed is not None
    assert claimed.job_id == inserted.job_id
    assert claimed.status == "running"
    assert claimed.started_at is not None


async def test_claim_next_pending_returns_none_when_empty(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        claimed = await jobs_repo.claim_next_pending(session)
        await session.commit()
    assert claimed is None


async def test_claim_next_pending_skips_locked(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Two concurrent claimers pick different rows (SKIP LOCKED in action)."""
    async with session_factory() as session:
        a = await jobs_repo.insert_pending(
            session, _make_request("c", "a"), callback_url=None
        )
        b = await jobs_repo.insert_pending(
            session, _make_request("c", "b"), callback_url=None
        )
        await session.commit()

    session_a = session_factory()
    session_b = session_factory()
    try:
        # Start the first claim but *don't* commit yet — its row is locked.
        first = await jobs_repo.claim_next_pending(session_a)
        # Second claimer runs while the first is still holding its lock. It
        # must see the 'a' row as pending but SKIP it, returning the 'b' row.
        second = await jobs_repo.claim_next_pending(session_b)

        assert first is not None and second is not None
        assert {first.job_id, second.job_id} == {a.job_id, b.job_id}
        assert first.job_id != second.job_id

        await session_a.commit()
        await session_b.commit()
    finally:
        await session_a.close()
        await session_b.close()
async def test_mark_done_writes_response_and_finishes(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    response = ResponseIX(
        use_case="bank_statement_header",
        ix_client_id="mammon",
        request_id="r-1",
    )

    async with session_factory() as session:
        await jobs_repo.mark_done(session, inserted.job_id, response)
        await session.commit()

    async with session_factory() as session:
        after = await jobs_repo.get(session, inserted.job_id)
    assert after is not None
    assert after.status == "done"
    assert after.response is not None
    assert after.finished_at is not None


async def test_mark_done_with_error_response_moves_to_error(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """`done` iff response.error is None — otherwise status='error'."""
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    bad = ResponseIX(error="IX_002_000: boom")

    async with session_factory() as session:
        await jobs_repo.mark_done(session, inserted.job_id, bad)
        await session.commit()

    async with session_factory() as session:
        after = await jobs_repo.get(session, inserted.job_id)
    assert after is not None
    assert after.status == "error"
    assert after.response is not None
    assert (after.response.error or "").startswith("IX_002_000")


async def test_mark_error_always_error(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    bad = ResponseIX(error="IX_000_005: unsupported")

    async with session_factory() as session:
        await jobs_repo.mark_error(session, inserted.job_id, bad)
        await session.commit()

    async with session_factory() as session:
        after = await jobs_repo.get(session, inserted.job_id)
    assert after is not None
    assert after.status == "error"
async def test_update_callback_status(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
) -> None:
|
||||
async with session_factory() as session:
|
||||
inserted = await jobs_repo.insert_pending(
|
||||
session, _make_request(), callback_url="http://cb"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
await jobs_repo.update_callback_status(session, inserted.job_id, "delivered")
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
after = await jobs_repo.get(session, inserted.job_id)
|
||||
assert after is not None
|
||||
assert after.callback_status == "delivered"
|
||||
|
||||
|
||||
async def test_sweep_orphans_resets_stale_running(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Running rows older than (now - max_running_seconds) go back to pending."""

    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request(), callback_url=None
        )
        await session.commit()

    # Backdate started_at by an hour to simulate a crashed worker mid-job.
    async with session_factory() as session:
        from sqlalchemy import text

        stale = datetime.now(UTC) - timedelta(hours=1)
        await session.execute(
            text(
                "UPDATE ix_jobs SET status='running', started_at=:t "
                "WHERE job_id=:jid"
            ),
            {"t": stale, "jid": inserted.job_id},
        )
        await session.commit()

    # Max age of 60 s → our hour-old row gets swept.
    async with session_factory() as session:
        rescued = await jobs_repo.sweep_orphans(
            session, datetime.now(UTC), max_running_seconds=60
        )
        await session.commit()

    assert inserted.job_id in rescued
    async with session_factory() as session:
        after = await jobs_repo.get(session, inserted.job_id)
        assert after is not None
        assert after.status == "pending"
        assert after.attempts == 1


async def test_sweep_orphans_leaves_fresh_running_alone(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """A just-claimed row must not get reclaimed by the sweeper."""

    async with session_factory() as session:
        await jobs_repo.insert_pending(session, _make_request(), callback_url=None)
        await session.commit()

    async with session_factory() as session:
        claimed = await jobs_repo.claim_next_pending(session)
        await session.commit()
    assert claimed is not None

    # Sweep with a huge threshold (1 hour). Our just-claimed row is fresh, so
    # it stays running.
    async with session_factory() as session:
        rescued = await jobs_repo.sweep_orphans(
            session, datetime.now(UTC), max_running_seconds=3600
        )
        await session.commit()

    assert rescued == []

    async with session_factory() as session:
        after = await jobs_repo.get(session, claimed.job_id)
        assert after is not None
        assert after.status == "running"

async def test_queue_position_pending_only(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Three pending rows in insertion order → positions 0, 1, 2; total 3.

    Each row is committed in its own transaction so the DB stamps a
    distinct ``created_at`` per row (``now()`` is transaction-stable).
    """

    async with session_factory() as session:
        a = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-a"), callback_url=None
        )
        await session.commit()
    async with session_factory() as session:
        b = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-b"), callback_url=None
        )
        await session.commit()
    async with session_factory() as session:
        c = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-c"), callback_url=None
        )
        await session.commit()

    async with session_factory() as session:
        pa = await jobs_repo.queue_position(session, a.job_id)
        pb = await jobs_repo.queue_position(session, b.job_id)
        pc = await jobs_repo.queue_position(session, c.job_id)

    # All three active; total == 3.
    assert pa == (0, 3)
    assert pb == (1, 3)
    assert pc == (2, 3)


async def test_queue_position_running_plus_pending(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """One running + two pending → running:(0,3), next:(1,3), last:(2,3)."""

    async with session_factory() as session:
        first = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-r-1"), callback_url=None
        )
        await session.commit()
    async with session_factory() as session:
        second = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-r-2"), callback_url=None
        )
        await session.commit()
    async with session_factory() as session:
        third = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-r-3"), callback_url=None
        )
        await session.commit()

    # Claim the first → it becomes running.
    async with session_factory() as session:
        claimed = await jobs_repo.claim_next_pending(session)
        await session.commit()
    assert claimed is not None
    assert claimed.job_id == first.job_id

    async with session_factory() as session:
        p_running = await jobs_repo.queue_position(session, first.job_id)
        p_second = await jobs_repo.queue_position(session, second.job_id)
        p_third = await jobs_repo.queue_position(session, third.job_id)

    # Running row reports 0 ahead (itself is the head).
    assert p_running == (0, 3)
    # Second pending: running is ahead (1) + zero older pendings.
    assert p_second == (1, 3)
    # Third pending: running ahead + one older pending.
    assert p_third == (2, 3)


async def test_queue_position_terminal_returns_zero_zero(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Finished jobs have no queue position — always (0, 0)."""

    async with session_factory() as session:
        inserted = await jobs_repo.insert_pending(
            session, _make_request("c", "qp-term"), callback_url=None
        )
        await session.commit()

    response = ResponseIX(
        use_case="bank_statement_header",
        ix_client_id="c",
        request_id="qp-term",
    )
    async with session_factory() as session:
        await jobs_repo.mark_done(session, inserted.job_id, response)
        await session.commit()

    async with session_factory() as session:
        pos = await jobs_repo.queue_position(session, inserted.job_id)

    assert pos == (0, 0)


async def test_queue_position_unknown_id_returns_zero_zero(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        pos = await jobs_repo.queue_position(session, uuid4())
    assert pos == (0, 0)

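The positions the four tests above assert follow one counting rule: over the active jobs (running or pending, ordered by creation), a job's position is the number of active jobs ahead of it, and the total is the active count; terminal or unknown ids get (0, 0). A pure-Python sketch of that arithmetic — dict keys are illustrative, and the repo computes this in SQL rather than in memory:

```python
def queue_position(jobs: list[dict], job_id: str) -> tuple[int, int]:
    """(ahead, total) over active jobs in created_at order; (0, 0) if not active."""
    # Active = running or pending; assume `jobs` is already in created_at order.
    active_ids = [j["id"] for j in jobs if j["status"] in ("running", "pending")]
    if job_id not in active_ids:
        return (0, 0)  # terminal or unknown id
    return (active_ids.index(job_id), len(active_ids))


rows = [
    {"id": "a", "status": "running"},
    {"id": "b", "status": "pending"},
    {"id": "c", "status": "pending"},
    {"id": "d", "status": "done"},
]
print(queue_position(rows, "a"))  # (0, 3) — the running head
print(queue_position(rows, "c"))  # (2, 3) — running + one older pending ahead
print(queue_position(rows, "d"))  # (0, 0) — terminal
```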
async def test_concurrent_claim_never_double_dispatches(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Spin a batch of concurrent claimers; every insert is claimed exactly once."""

    async with session_factory() as session:
        ids = []
        for i in range(5):
            job = await jobs_repo.insert_pending(
                session, _make_request("mass", f"r-{i}"), callback_url=None
            )
            ids.append(job.job_id)
        await session.commit()

    async def claim_one() -> UUID | None:
        async with session_factory() as session:
            claimed = await jobs_repo.claim_next_pending(session)
            await session.commit()
            return claimed.job_id if claimed else None

    results = await asyncio.gather(*(claim_one() for _ in range(10)))
    non_null = [r for r in results if r is not None]
    # Ten claimers race for five rows: every inserted id is claimed exactly
    # once, and the surplus claimers come back empty-handed.
    assert sorted(non_null) == sorted(ids)

# ---------- list_recent ---------------------------------------------------
#
# The UI's ``/ui/jobs`` page needs a paginated, filterable view of recent
# jobs. We keep the contract intentionally small: list_recent returns
# ``(jobs, total)`` — ``total`` is the count after filters but before
# limit/offset — so the template can render "Showing N of M".


async def test_list_recent_empty_db(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        jobs, total = await jobs_repo.list_recent(session, limit=50, offset=0)
    assert jobs == []
    assert total == 0


async def test_list_recent_orders_newest_first(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    ids: list[UUID] = []
    for i in range(3):
        async with session_factory() as session:
            job = await jobs_repo.insert_pending(
                session, _make_request("c", f"lr-{i}"), callback_url=None
            )
            await session.commit()
        ids.append(job.job_id)

    async with session_factory() as session:
        jobs, total = await jobs_repo.list_recent(session, limit=50, offset=0)

    assert total == 3
    # Newest first → reverse of insertion order.
    assert [j.job_id for j in jobs] == list(reversed(ids))


async def test_list_recent_status_single_filter(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    # Two pending, one done.
    async with session_factory() as session:
        for i in range(3):
            await jobs_repo.insert_pending(
                session, _make_request("c", f"sf-{i}"), callback_url=None
            )
        await session.commit()

    async with session_factory() as session:
        claimed = await jobs_repo.claim_next_pending(session)
        assert claimed is not None
        await jobs_repo.mark_done(
            session,
            claimed.job_id,
            ResponseIX(
                use_case="bank_statement_header",
                ix_client_id="c",
                request_id=claimed.request_id,
            ),
        )
        await session.commit()

    async with session_factory() as session:
        done_jobs, done_total = await jobs_repo.list_recent(
            session, limit=50, offset=0, status="done"
        )
    assert done_total == 1
    assert len(done_jobs) == 1
    assert done_jobs[0].status == "done"

    async with session_factory() as session:
        pending_jobs, pending_total = await jobs_repo.list_recent(
            session, limit=50, offset=0, status="pending"
        )
    assert pending_total == 2
    assert all(j.status == "pending" for j in pending_jobs)


async def test_list_recent_status_iterable_filter(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    # Two pending, one done, one errored.
    async with session_factory() as session:
        for i in range(4):
            await jobs_repo.insert_pending(
                session, _make_request("c", f"if-{i}"), callback_url=None
            )
        await session.commit()

    async with session_factory() as session:
        a = await jobs_repo.claim_next_pending(session)
        assert a is not None
        await jobs_repo.mark_done(
            session,
            a.job_id,
            ResponseIX(
                use_case="bank_statement_header",
                ix_client_id="c",
                request_id=a.request_id,
            ),
        )
        await session.commit()
    async with session_factory() as session:
        b = await jobs_repo.claim_next_pending(session)
        assert b is not None
        await jobs_repo.mark_error(session, b.job_id, ResponseIX(error="boom"))
        await session.commit()

    async with session_factory() as session:
        jobs, total = await jobs_repo.list_recent(
            session, limit=50, offset=0, status=["done", "error"]
        )
    assert total == 2
    assert {j.status for j in jobs} == {"done", "error"}


async def test_list_recent_client_id_filter(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        await jobs_repo.insert_pending(
            session, _make_request("alpha", "a-1"), callback_url=None
        )
        await jobs_repo.insert_pending(
            session, _make_request("beta", "b-1"), callback_url=None
        )
        await jobs_repo.insert_pending(
            session, _make_request("alpha", "a-2"), callback_url=None
        )
        await session.commit()

    async with session_factory() as session:
        jobs, total = await jobs_repo.list_recent(
            session, limit=50, offset=0, client_id="alpha"
        )
    assert total == 2
    assert all(j.client_id == "alpha" for j in jobs)


async def test_list_recent_pagination(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    ids: list[UUID] = []
    for i in range(7):
        async with session_factory() as session:
            job = await jobs_repo.insert_pending(
                session, _make_request("c", f"pg-{i}"), callback_url=None
            )
            await session.commit()
        ids.append(job.job_id)

    async with session_factory() as session:
        page1, total1 = await jobs_repo.list_recent(
            session, limit=3, offset=0
        )
    assert total1 == 7
    assert len(page1) == 3
    # Newest three are the last three inserted.
    assert [j.job_id for j in page1] == list(reversed(ids[-3:]))

    async with session_factory() as session:
        page2, total2 = await jobs_repo.list_recent(
            session, limit=3, offset=3
        )
    assert total2 == 7
    assert len(page2) == 3
    expected = list(reversed(ids))[3:6]
    assert [j.job_id for j in page2] == expected

    async with session_factory() as session:
        page3, total3 = await jobs_repo.list_recent(
            session, limit=3, offset=6
        )
    assert total3 == 7
    assert len(page3) == 1
    assert page3[0].job_id == ids[0]


async def test_list_recent_caps_limit(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """limit is capped at 200 — asking for 9999 gets clamped."""

    async with session_factory() as session:
        jobs, total = await jobs_repo.list_recent(
            session, limit=9999, offset=0
        )
    assert total == 0
    assert jobs == []


async def test_list_recent_rejects_negative_offset(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    async with session_factory() as session:
        import pytest as _pytest

        with _pytest.raises(ValueError):
            await jobs_repo.list_recent(session, limit=50, offset=-1)
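The contract the tests above pin down — newest-first order, `total` counted after filters but before limit/offset, a 200-row cap, and a `ValueError` on negative offsets — can be sketched over an in-memory list. Function and key names here are illustrative; the real `list_recent` runs this as SQL:

```python
def list_recent(jobs: list[dict], *, limit: int, offset: int, status=None):
    """Return (page, total): total is counted after filters, before slicing."""
    if offset < 0:
        raise ValueError("offset must be >= 0")
    limit = min(limit, 200)  # mirror the cap test_list_recent_caps_limit relies on
    if status is not None:
        wanted = {status} if isinstance(status, str) else set(status)
        jobs = [j for j in jobs if j["status"] in wanted]
    ordered = sorted(jobs, key=lambda j: j["created_at"], reverse=True)  # newest first
    return ordered[offset : offset + limit], len(ordered)


rows = [{"id": i, "status": "pending", "created_at": i} for i in range(7)]
page, total = list_recent(rows, limit=3, offset=3)
print(total)                     # 7
print([j["id"] for j in page])   # [3, 2, 1]
```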
153  tests/integration/test_pg_queue_adapter.py  Normal file
@@ -0,0 +1,153 @@
"""Integration tests for the PgQueueListener + worker integration (Task 3.6).
|
||||
|
||||
Two scenarios:
|
||||
|
||||
1. NOTIFY delivered — worker wakes within ~1 s and picks the row up.
|
||||
2. Missed NOTIFY — the row still gets picked up by the fallback poll.
|
||||
|
||||
Both run a real worker + listener against a live Postgres. We drive them via
|
||||
``asyncio.gather`` + a "until done" watchdog.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from ix.adapters.pg_queue.listener import PgQueueListener, asyncpg_dsn_from_sqlalchemy_url
|
||||
from ix.contracts.request import Context, RequestIX
|
||||
from ix.pipeline.pipeline import Pipeline
|
||||
from ix.pipeline.step import Step
|
||||
from ix.store import jobs_repo
|
||||
from ix.worker.loop import Worker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
class _PassingStep(Step):
|
||||
"""Same minimal fake as test_worker_loop — keeps these suites independent."""
|
||||
|
||||
step_name = "fake_pass"
|
||||
|
||||
async def validate(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
return True
|
||||
|
||||
async def process(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
response_ix.use_case = request_ix.use_case
|
||||
return response_ix
|
||||
|
||||
|
||||
def _factory() -> Pipeline:
|
||||
return Pipeline(steps=[_PassingStep()])
|
||||
|
||||
|
||||
async def _wait_for_status(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
job_id,
|
||||
target: str,
|
||||
timeout_s: float,
|
||||
) -> bool:
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.get(session, job_id)
|
||||
if job is not None and job.status == target:
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
return False
|
||||
|
||||
|
||||
async def test_notify_wakes_worker_within_2s(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
postgres_url: str,
|
||||
) -> None:
|
||||
"""Direct INSERT + NOTIFY → worker picks it up fast (not via the poll)."""
|
||||
|
||||
listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
|
||||
await listener.start()
|
||||
|
||||
worker = Worker(
|
||||
session_factory=session_factory,
|
||||
pipeline_factory=_factory,
|
||||
# 60 s fallback poll — if we still find the row within 2 s it's
|
||||
# because NOTIFY woke us, not the poll.
|
||||
poll_interval_seconds=60.0,
|
||||
max_running_seconds=3600,
|
||||
wait_for_work=listener.wait_for_work,
|
||||
)
|
||||
stop = asyncio.Event()
|
||||
worker_task = asyncio.create_task(worker.run(stop))
|
||||
|
||||
# Give the worker one tick to reach the sleep_or_wake branch.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Insert a pending row manually + NOTIFY — simulates a direct-SQL client
|
||||
# like an external batch script.
|
||||
request = RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="pgq",
|
||||
request_id="notify-1",
|
||||
context=Context(texts=["hi"]),
|
||||
)
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.insert_pending(session, request, callback_url=None)
|
||||
await session.commit()
|
||||
async with session_factory() as session:
|
||||
await session.execute(
|
||||
text(f"NOTIFY ix_jobs_new, '{job.job_id}'")
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
assert await _wait_for_status(session_factory, job.job_id, "done", 3.0), (
|
||||
"worker didn't pick up the NOTIFY'd row in time"
|
||||
)
|
||||
|
||||
stop.set()
|
||||
await worker_task
|
||||
await listener.stop()
|
||||
|
||||
|
||||
async def test_missed_notify_falls_back_to_poll(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
postgres_url: str,
|
||||
) -> None:
|
||||
"""Row lands without a NOTIFY; fallback poll still picks it up."""
|
||||
|
||||
listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
|
||||
await listener.start()
|
||||
|
||||
worker = Worker(
|
||||
session_factory=session_factory,
|
||||
pipeline_factory=_factory,
|
||||
# Short poll so the fallback kicks in quickly — we need the test
|
||||
# to finish in seconds, not the spec's 10 s.
|
||||
poll_interval_seconds=0.5,
|
||||
max_running_seconds=3600,
|
||||
wait_for_work=listener.wait_for_work,
|
||||
)
|
||||
stop = asyncio.Event()
|
||||
worker_task = asyncio.create_task(worker.run(stop))
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Insert without NOTIFY: simulate a buggy writer.
|
||||
request = RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="pgq",
|
||||
request_id="missed-1",
|
||||
context=Context(texts=["hi"]),
|
||||
)
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.insert_pending(session, request, callback_url=None)
|
||||
await session.commit()
|
||||
|
||||
assert await _wait_for_status(session_factory, job.job_id, "done", 5.0), (
|
||||
"fallback poll didn't pick up the row"
|
||||
)
|
||||
|
||||
stop.set()
|
||||
await worker_task
|
||||
await listener.stop()
|
||||
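Both scenarios hinge on one loop shape inside the worker: wait on the listener's wake-up with the poll interval as a timeout, and claim work on either outcome. A minimal sketch of that wait using `asyncio.wait_for` — the function name and the real Worker's internals are assumptions, not shown in this diff:

```python
import asyncio


async def sleep_or_wake(wake: asyncio.Event, poll_interval_seconds: float) -> str:
    """Return "notify" if woken by the event, "poll" on timeout."""
    try:
        await asyncio.wait_for(wake.wait(), timeout=poll_interval_seconds)
        wake.clear()
        return "notify"
    except asyncio.TimeoutError:
        return "poll"


async def main() -> tuple[str, str]:
    wake = asyncio.Event()
    wake.set()  # a NOTIFY arrived → wake immediately, long poll never fires
    first = await sleep_or_wake(wake, poll_interval_seconds=60.0)
    # Missed NOTIFY → the short timeout is the fallback poll kicking in.
    second = await sleep_or_wake(wake, poll_interval_seconds=0.05)
    return first, second


print(asyncio.run(main()))  # ('notify', 'poll')
```

Either branch ends with the worker attempting a claim, which is why a missed NOTIFY only costs latency, never correctness.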
179  tests/integration/test_rest_adapter.py  Normal file
@@ -0,0 +1,179 @@
"""Integration tests for the FastAPI REST adapter (spec §5).
|
||||
|
||||
Uses ``fastapi.testclient.TestClient`` against a real DB. Ollama / OCR probes
|
||||
are stubbed via the DI hooks the routes expose for testing — in Chunk 4 the
|
||||
production probes swap in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from ix.adapters.rest.routes import Probes, get_probes, get_session_factory_dep
|
||||
from ix.app import create_app
|
||||
|
||||
|
||||
def _factory_for_url(postgres_url: str): # type: ignore[no-untyped-def]
|
||||
"""Build a TestClient-compatible session factory.
|
||||
|
||||
TestClient runs the ASGI app on its own dedicated event loop (the one it
|
||||
creates in its sync wrapper), distinct from the per-test loop
|
||||
pytest-asyncio gives direct tests. Session factories must therefore be
|
||||
constructed from an engine that was itself created on that inner loop.
|
||||
We do this lazily: each dependency resolution creates a fresh engine +
|
||||
factory on the current running loop, which is the TestClient's loop at
|
||||
route-invocation time. Engine reuse would drag the cross-loop futures
|
||||
that asyncpg hates back in.
|
||||
"""
|
||||
|
||||
def _factory(): # type: ignore[no-untyped-def]
|
||||
eng = create_async_engine(postgres_url, pool_pre_ping=True)
|
||||
return async_sessionmaker(eng, expire_on_commit=False)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(postgres_url: str) -> Iterator[TestClient]:
|
||||
"""Spin up the FastAPI app wired to the test DB + stub probes."""
|
||||
|
||||
app_obj = create_app(spawn_worker=False)
|
||||
app_obj.dependency_overrides[get_session_factory_dep] = _factory_for_url(
|
||||
postgres_url
|
||||
)
|
||||
app_obj.dependency_overrides[get_probes] = lambda: Probes(
|
||||
ollama=lambda: "ok",
|
||||
ocr=lambda: "ok",
|
||||
)
|
||||
with TestClient(app_obj) as client:
|
||||
yield client
|
||||
|
||||
|
||||
def _valid_request_body(client_id: str = "mammon", request_id: str = "r-1") -> dict:
|
||||
return {
|
||||
"use_case": "bank_statement_header",
|
||||
"ix_client_id": client_id,
|
||||
"request_id": request_id,
|
||||
"context": {"texts": ["hello world"]},
|
||||
}
|
||||
|
||||
|
||||
def test_post_jobs_creates_pending(app: TestClient) -> None:
|
||||
resp = app.post("/jobs", json=_valid_request_body())
|
||||
assert resp.status_code == 201, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "pending"
|
||||
assert len(body["ix_id"]) == 16
|
||||
assert body["job_id"]
|
||||
|
||||
|
||||
def test_post_jobs_idempotent_returns_200(app: TestClient) -> None:
|
||||
first = app.post("/jobs", json=_valid_request_body("m", "dup"))
|
||||
assert first.status_code == 201
|
||||
first_body = first.json()
|
||||
|
||||
second = app.post("/jobs", json=_valid_request_body("m", "dup"))
|
||||
assert second.status_code == 200
|
||||
second_body = second.json()
|
||||
assert second_body["job_id"] == first_body["job_id"]
|
||||
assert second_body["ix_id"] == first_body["ix_id"]
|
||||
|
||||
|
||||
def test_get_job_by_id(app: TestClient) -> None:
|
||||
created = app.post("/jobs", json=_valid_request_body()).json()
|
||||
resp = app.get(f"/jobs/{created['job_id']}")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["job_id"] == created["job_id"]
|
||||
assert body["request"]["use_case"] == "bank_statement_header"
|
||||
assert body["status"] == "pending"
|
||||
|
||||
|
||||
def test_get_job_404(app: TestClient) -> None:
|
||||
resp = app.get(f"/jobs/{uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_get_by_correlation_query(app: TestClient) -> None:
|
||||
created = app.post("/jobs", json=_valid_request_body("mammon", "corr-1")).json()
|
||||
resp = app.get("/jobs", params={"client_id": "mammon", "request_id": "corr-1"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["job_id"] == created["job_id"]
|
||||
|
||||
missing = app.get("/jobs", params={"client_id": "mammon", "request_id": "nope"})
|
||||
assert missing.status_code == 404
|
||||
|
||||
|
||||
def test_healthz_all_ok(app: TestClient) -> None:
|
||||
resp = app.get("/healthz")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["postgres"] == "ok"
|
||||
assert body["ollama"] == "ok"
|
||||
assert body["ocr"] == "ok"
|
||||
|
||||
|
||||
def test_healthz_503_on_postgres_fail(postgres_url: str) -> None:
|
||||
"""Broken postgres probe → 503. Ollama/OCR still surface in the body."""
|
||||
|
||||
app_obj = create_app(spawn_worker=False)
|
||||
|
||||
def _bad_factory(): # type: ignore[no-untyped-def]
|
||||
def _raise(): # type: ignore[no-untyped-def]
|
||||
raise RuntimeError("db down")
|
||||
|
||||
return _raise
|
||||
|
||||
app_obj.dependency_overrides[get_session_factory_dep] = _bad_factory
|
||||
app_obj.dependency_overrides[get_probes] = lambda: Probes(
|
||||
ollama=lambda: "ok", ocr=lambda: "ok"
|
||||
)
|
||||
|
||||
with TestClient(app_obj) as client:
|
||||
resp = client.get("/healthz")
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["postgres"] == "fail"
|
||||
|
||||
|
||||
def test_healthz_degraded_ollama_is_503(postgres_url: str) -> None:
|
||||
"""Per spec §5: degraded flips HTTP to 503 (only all-ok yields 200)."""
|
||||
|
||||
app_obj = create_app(spawn_worker=False)
|
||||
app_obj.dependency_overrides[get_session_factory_dep] = _factory_for_url(
|
||||
postgres_url
|
||||
)
|
||||
app_obj.dependency_overrides[get_probes] = lambda: Probes(
|
||||
ollama=lambda: "degraded", ocr=lambda: "ok"
|
||||
)
|
||||
|
||||
with TestClient(app_obj) as client:
|
||||
resp = client.get("/healthz")
|
||||
assert resp.status_code == 503
|
||||
assert resp.json()["ollama"] == "degraded"
|
||||
|
||||
|
||||
def test_metrics_shape(app: TestClient) -> None:
|
||||
# Submit a couple of pending jobs to populate counters.
|
||||
app.post("/jobs", json=_valid_request_body("mm", "a"))
|
||||
app.post("/jobs", json=_valid_request_body("mm", "b"))
|
||||
|
||||
resp = app.get("/metrics")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
for key in (
|
||||
"jobs_pending",
|
||||
"jobs_running",
|
||||
"jobs_done_24h",
|
||||
"jobs_error_24h",
|
||||
"by_use_case_seconds",
|
||||
):
|
||||
assert key in body
|
||||
assert body["jobs_pending"] == 2
|
||||
assert body["jobs_running"] == 0
|
||||
assert isinstance(body["by_use_case_seconds"], dict)
|
||||
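The three healthz tests above imply one aggregation rule: the endpoint returns 200 only when every probe reports "ok"; any "fail" or "degraded" flips the status to 503 while the per-probe values still surface in the body. A sketch of that fold — probe names mirror the tests' Probes stub, but the real route's code isn't shown in this diff:

```python
def healthz(probes: dict[str, str]) -> tuple[int, dict[str, str]]:
    """All-ok → 200; any other probe value (fail, degraded) → 503."""
    status = 200 if all(v == "ok" for v in probes.values()) else 503
    return status, probes  # the body always carries per-probe detail


print(healthz({"postgres": "ok", "ollama": "ok", "ocr": "ok"})[0])        # 200
print(healthz({"postgres": "ok", "ollama": "degraded", "ocr": "ok"})[0])  # 503
print(healthz({"postgres": "fail", "ollama": "ok", "ocr": "ok"})[0])      # 503
```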
792  tests/integration/test_ui_routes.py  Normal file
@@ -0,0 +1,792 @@
"""Integration tests for the `/ui` router (spec §PR 2).
|
||||
|
||||
Covers the full round-trip through `POST /ui/jobs` — the handler parses
|
||||
multipart form data into a `RequestIX` and hands it to
|
||||
`ix.store.jobs_repo.insert_pending`, the same entry point the REST adapter
|
||||
uses. Tests assert the job row exists with the right client/request ids and
|
||||
that custom-use-case forms produce a `use_case_inline` block in the stored
|
||||
request JSON.
|
||||
|
||||
The DB-touching tests depend on the shared integration conftest which
|
||||
spins up migrations against the configured Postgres; the pure-template
|
||||
tests (`GET /ui` and the fragment renderer) still need a factory but
|
||||
won't actually query — they're cheap.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from ix.adapters.rest.routes import Probes, get_probes, get_session_factory_dep
|
||||
from ix.app import create_app
|
||||
from ix.store.models import IxJob
|
||||
|
||||
FIXTURE_DIR = Path(__file__).resolve().parents[1] / "fixtures"
|
||||
FIXTURE_PDF = FIXTURE_DIR / "synthetic_giro.pdf"
|
||||
|
||||
|
||||
def _factory_for_url(postgres_url: str): # type: ignore[no-untyped-def]
|
||||
def _factory(): # type: ignore[no-untyped-def]
|
||||
eng = create_async_engine(postgres_url, pool_pre_ping=True)
|
||||
return async_sessionmaker(eng, expire_on_commit=False)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(postgres_url: str) -> Iterator[TestClient]:
|
||||
app_obj = create_app(spawn_worker=False)
|
||||
app_obj.dependency_overrides[get_session_factory_dep] = _factory_for_url(
|
||||
postgres_url
|
||||
)
|
||||
app_obj.dependency_overrides[get_probes] = lambda: Probes(
|
||||
ollama=lambda: "ok", ocr=lambda: "ok"
|
||||
)
|
||||
with TestClient(app_obj) as client:
|
||||
yield client
|
||||
|
||||
|
||||
class TestIndexPage:
|
||||
def test_index_returns_html(self, app: TestClient) -> None:
|
||||
resp = app.get("/ui")
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
body = resp.text
|
||||
# Dropdown prefilled with the registered use case.
|
||||
assert "bank_statement_header" in body
|
||||
# Marker for the submission form.
|
||||
assert '<form' in body
|
||||
|
||||
def test_static_mount_is_reachable(self, app: TestClient) -> None:
|
||||
# StaticFiles returns 404 for the keepfile; the mount itself must
|
||||
# exist so asset URLs resolve. We probe the directory root instead.
|
||||
resp = app.get("/ui/static/.gitkeep")
|
||||
# .gitkeep exists in the repo — expect 200 (or at minimum not a 404
|
||||
# due to missing mount). A 405/403 would also indicate the mount is
|
||||
# wired; we assert the response is *not* a 404 from a missing route.
|
||||
assert resp.status_code != 404
|
||||
|
||||
|
||||
class TestSubmitJobRegistered:
|
||||
def test_post_registered_use_case_creates_row(
|
||||
self,
|
||||
app: TestClient,
|
||||
postgres_url: str,
|
||||
) -> None:
|
||||
request_id = f"ui-reg-{uuid4().hex[:8]}"
|
||||
with FIXTURE_PDF.open("rb") as fh:
|
||||
resp = app.post(
|
||||
"/ui/jobs",
|
||||
data={
|
||||
"use_case_mode": "registered",
|
||||
"use_case_name": "bank_statement_header",
|
||||
"ix_client_id": "ui-test",
|
||||
"request_id": request_id,
|
||||
"texts": "",
|
||||
"use_ocr": "on",
|
||||
"include_provenance": "on",
|
||||
"max_sources_per_field": "10",
|
||||
},
|
||||
files={"pdf": ("sample.pdf", fh, "application/pdf")},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code in (200, 303), resp.text
|
||||
|
||||
# Assert the row exists in the DB.
|
||||
job_row = _find_job(postgres_url, "ui-test", request_id)
|
||||
assert job_row is not None
|
||||
assert job_row.status == "pending"
|
||||
assert job_row.request["use_case"] == "bank_statement_header"
|
||||
# Context.files must reference a local file:// path.
|
||||
files = job_row.request["context"]["files"]
|
||||
assert len(files) == 1
|
||||
entry = files[0]
|
||||
url = entry if isinstance(entry, str) else entry["url"]
|
||||
assert url.startswith("file://")
|
||||
|
||||
def test_htmx_submit_uses_hx_redirect_header(
|
||||
self,
|
||||
app: TestClient,
|
||||
) -> None:
|
||||
request_id = f"ui-htmx-{uuid4().hex[:8]}"
|
||||
with FIXTURE_PDF.open("rb") as fh:
|
||||
resp = app.post(
|
||||
"/ui/jobs",
|
||||
data={
|
||||
"use_case_mode": "registered",
|
||||
"use_case_name": "bank_statement_header",
|
||||
"ix_client_id": "ui-test",
|
||||
"request_id": request_id,
|
||||
},
|
||||
files={"pdf": ("sample.pdf", fh, "application/pdf")},
|
||||
headers={"HX-Request": "true"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "HX-Redirect" in resp.headers
|
||||
|
||||
|
||||
class TestSubmitJobCustom:
    def test_post_custom_use_case_stores_inline(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        request_id = f"ui-cust-{uuid4().hex[:8]}"
        fields_json = json.dumps(
            [
                {"name": "vendor", "type": "str", "required": True},
                {"name": "total", "type": "decimal"},
            ]
        )
        with FIXTURE_PDF.open("rb") as fh:
            resp = app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "custom",
                    "use_case_name": "invoice_adhoc",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                    "system_prompt": "Extract vendor and total.",
                    "default_model": "qwen3:14b",
                    "fields_json": fields_json,
                },
                files={"pdf": ("sample.pdf", fh, "application/pdf")},
                follow_redirects=False,
            )
        assert resp.status_code in (200, 303), resp.text
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None
        stored = job_row.request["use_case_inline"]
        assert stored is not None
        assert stored["use_case_name"] == "invoice_adhoc"
        assert stored["system_prompt"] == "Extract vendor and total."
        names = [f["name"] for f in stored["fields"]]
        assert names == ["vendor", "total"]

    def test_post_malformed_fields_json_rejected(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        request_id = f"ui-bad-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            resp = app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "custom",
                    "use_case_name": "adhoc_bad",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                    "system_prompt": "p",
                    "fields_json": "this is not json",
                },
                files={"pdf": ("sample.pdf", fh, "application/pdf")},
                follow_redirects=False,
            )
        # Either re-rendered form (422 / 200 with error) — what matters is
        # that no row was inserted.
        assert resp.status_code in (200, 400, 422)
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is None
        # A helpful error should appear somewhere in the body.
        assert (
            "error" in resp.text.lower()
            or "invalid" in resp.text.lower()
            or "json" in resp.text.lower()
        )


class TestDisplayName:
    def test_post_persists_display_name_in_file_ref(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        """The client-provided upload filename lands in FileRef.display_name."""

        request_id = f"ui-name-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            resp = app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                },
                files={"pdf": ("my statement.pdf", fh, "application/pdf")},
                follow_redirects=False,
            )
        assert resp.status_code in (200, 303), resp.text

        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None
        entry = job_row.request["context"]["files"][0]
        assert isinstance(entry, dict)
        assert entry["display_name"] == "my statement.pdf"


class TestFragment:
    def test_fragment_pending_has_trigger(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        request_id = f"ui-frag-p-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                },
                files={"pdf": ("sample.pdf", fh, "application/pdf")},
                follow_redirects=False,
            )
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None

        resp = app.get(f"/ui/jobs/{job_row.job_id}/fragment")
        assert resp.status_code == 200
        body = resp.text
        # Pending → auto-refresh every 2s.
        assert "hx-trigger" in body
        assert "2s" in body
        assert "pending" in body.lower() or "running" in body.lower()
        # New queue-awareness copy.
        assert "Queue position" in body or "About to start" in body

    def test_fragment_pending_shows_filename(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        request_id = f"ui-frag-pf-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                },
                files={
                    "pdf": (
                        "client-side-name.pdf",
                        fh,
                        "application/pdf",
                    )
                },
                follow_redirects=False,
            )
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None
        resp = app.get(f"/ui/jobs/{job_row.job_id}/fragment")
        assert resp.status_code == 200
        assert "client-side-name.pdf" in resp.text

    def test_fragment_running_shows_elapsed(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        """After flipping a row to running with a backdated started_at, the
        fragment renders a ``Running for MM:SS`` line."""

        request_id = f"ui-frag-r-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                },
                files={"pdf": ("sample.pdf", fh, "application/pdf")},
                follow_redirects=False,
            )
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None

        _force_running(postgres_url, job_row.job_id)

        resp = app.get(f"/ui/jobs/{job_row.job_id}/fragment")
        assert resp.status_code == 200
        body = resp.text
        assert "Running for" in body
        # MM:SS; our backdate is ~10s so expect 00:1? or higher.
        import re

        assert re.search(r"\d{2}:\d{2}", body), body

    def test_fragment_backward_compat_no_display_name(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        """Older rows (stored before display_name existed) must still render."""

        from ix.contracts.request import Context, FileRef, RequestIX

        legacy_req = RequestIX(
            use_case="bank_statement_header",
            ix_client_id="ui-test",
            request_id=f"ui-legacy-{uuid4().hex[:8]}",
            context=Context(
                files=[FileRef(url="file:///tmp/ix/ui/legacy.pdf")]
            ),
        )

        import asyncio

        from ix.store import jobs_repo as _repo

        async def _insert() -> UUID:
            eng = create_async_engine(postgres_url)
            sf = async_sessionmaker(eng, expire_on_commit=False)
            try:
                async with sf() as session:
                    job = await _repo.insert_pending(
                        session, legacy_req, callback_url=None
                    )
                    await session.commit()
                    return job.job_id
            finally:
                await eng.dispose()

        job_id = asyncio.run(_insert())
        resp = app.get(f"/ui/jobs/{job_id}/fragment")
        assert resp.status_code == 200
        body = resp.text
        # Must not crash; must include the fallback basename from the URL.
        assert "legacy.pdf" in body

    def test_fragment_done_shows_pretty_json(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        request_id = f"ui-frag-d-{uuid4().hex[:8]}"
        with FIXTURE_PDF.open("rb") as fh:
            app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": "ui-test",
                    "request_id": request_id,
                },
                files={
                    "pdf": (
                        "my-done-doc.pdf",
                        fh,
                        "application/pdf",
                    )
                },
                follow_redirects=False,
            )
        job_row = _find_job(postgres_url, "ui-test", request_id)
        assert job_row is not None

        # Hand-tick the row to done with a fake response.
        _force_done(
            postgres_url,
            job_row.job_id,
            response_body={
                "use_case": "bank_statement_header",
                "ix_result": {"result": {"bank_name": "UBS AG", "currency": "CHF"}},
            },
        )

        resp = app.get(f"/ui/jobs/{job_row.job_id}/fragment")
        assert resp.status_code == 200
        body = resp.text
        # Terminal → no auto-refresh.
        assert "every 2s" not in body
        # JSON present.
        assert "UBS AG" in body
        assert "CHF" in body
        # Filename surfaced on the done fragment.
        assert "my-done-doc.pdf" in body


class TestJobsListPage:
    """Tests for the ``GET /ui/jobs`` listing page (feat/ui-jobs-list)."""

    def _submit(
        self,
        app: TestClient,
        client_id: str,
        request_id: str,
        filename: str = "sample.pdf",
    ) -> None:
        with FIXTURE_PDF.open("rb") as fh:
            app.post(
                "/ui/jobs",
                data={
                    "use_case_mode": "registered",
                    "use_case_name": "bank_statement_header",
                    "ix_client_id": client_id,
                    "request_id": request_id,
                },
                files={"pdf": (filename, fh, "application/pdf")},
                follow_redirects=False,
            )

    def test_jobs_list_returns_html(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        for i in range(3):
            self._submit(
                app,
                "ui-list",
                f"lp-{uuid4().hex[:6]}-{i}",
                filename=f"doc-{i}.pdf",
            )

        resp = app.get("/ui/jobs")
        assert resp.status_code == 200
        assert "text/html" in resp.headers["content-type"]
        body = resp.text
        # Breadcrumb / header shows "Jobs".
        assert "Jobs" in body
        # display_name surfaces for each row.
        for i in range(3):
            assert f"doc-{i}.pdf" in body
        # Showing N of M counter present.
        assert "Showing" in body
        assert "of" in body

    def test_jobs_list_links_to_job_detail(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        rid = f"lp-link-{uuid4().hex[:6]}"
        self._submit(app, "ui-list", rid)
        row = _find_job(postgres_url, "ui-list", rid)
        assert row is not None
        resp = app.get("/ui/jobs")
        assert resp.status_code == 200
        assert f"/ui/jobs/{row.job_id}" in resp.text

    def test_jobs_list_status_filter_single(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        # Create two jobs, flip one to done.
        rid_pending = f"lp-p-{uuid4().hex[:6]}"
        rid_done = f"lp-d-{uuid4().hex[:6]}"
        self._submit(app, "ui-filt", rid_pending, filename="pending-doc.pdf")
        self._submit(app, "ui-filt", rid_done, filename="done-doc.pdf")
        done_row = _find_job(postgres_url, "ui-filt", rid_done)
        assert done_row is not None
        _force_done(
            postgres_url,
            done_row.job_id,
            response_body={"use_case": "bank_statement_header"},
        )

        # ?status=done → only done row shown.
        resp = app.get("/ui/jobs?status=done")
        assert resp.status_code == 200
        assert "done-doc.pdf" in resp.text
        assert "pending-doc.pdf" not in resp.text

    def test_jobs_list_status_filter_multi(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        rid_p = f"lp-mp-{uuid4().hex[:6]}"
        rid_d = f"lp-md-{uuid4().hex[:6]}"
        rid_e = f"lp-me-{uuid4().hex[:6]}"
        self._submit(app, "ui-multi", rid_p, filename="pending-m.pdf")
        self._submit(app, "ui-multi", rid_d, filename="done-m.pdf")
        self._submit(app, "ui-multi", rid_e, filename="error-m.pdf")

        done_row = _find_job(postgres_url, "ui-multi", rid_d)
        err_row = _find_job(postgres_url, "ui-multi", rid_e)
        assert done_row is not None and err_row is not None
        _force_done(
            postgres_url,
            done_row.job_id,
            response_body={"use_case": "bank_statement_header"},
        )
        _force_error(postgres_url, err_row.job_id)

        resp = app.get("/ui/jobs?status=done&status=error")
        assert resp.status_code == 200
        body = resp.text
        assert "done-m.pdf" in body
        assert "error-m.pdf" in body
        assert "pending-m.pdf" not in body

    def test_jobs_list_client_id_filter(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        rid_a = f"lp-a-{uuid4().hex[:6]}"
        rid_b = f"lp-b-{uuid4().hex[:6]}"
        self._submit(app, "client-alpha", rid_a, filename="alpha.pdf")
        self._submit(app, "client-beta", rid_b, filename="beta.pdf")

        resp = app.get("/ui/jobs?client_id=client-alpha")
        assert resp.status_code == 200
        body = resp.text
        assert "alpha.pdf" in body
        assert "beta.pdf" not in body

    def test_jobs_list_pagination(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        rids = []
        for i in range(7):
            rid = f"lp-pg-{uuid4().hex[:6]}-{i}"
            rids.append(rid)
            self._submit(app, "ui-pg", rid, filename=f"pg-{i}.pdf")

        resp_p1 = app.get("/ui/jobs?limit=5&offset=0&client_id=ui-pg")
        assert resp_p1.status_code == 200
        body_p1 = resp_p1.text
        # Newest-first: last 5 uploaded are pg-6..pg-2.
        for i in (2, 3, 4, 5, 6):
            assert f"pg-{i}.pdf" in body_p1
        assert "pg-1.pdf" not in body_p1
        assert "pg-0.pdf" not in body_p1

        resp_p2 = app.get("/ui/jobs?limit=5&offset=5&client_id=ui-pg")
        assert resp_p2.status_code == 200
        body_p2 = resp_p2.text
        assert "pg-1.pdf" in body_p2
        assert "pg-0.pdf" in body_p2
        # Showing 2 of 7 on page 2.
        assert "of 7" in body_p2

    def test_jobs_list_missing_display_name_falls_back_to_basename(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        """Legacy rows without display_name must still render via basename."""

        from ix.contracts.request import Context, FileRef, RequestIX

        legacy_req = RequestIX(
            use_case="bank_statement_header",
            ix_client_id="ui-legacy",
            request_id=f"lp-legacy-{uuid4().hex[:6]}",
            context=Context(
                files=[FileRef(url="file:///tmp/ix/ui/listing-legacy.pdf")]
            ),
        )

        import asyncio

        from ix.store import jobs_repo as _repo

        async def _insert() -> UUID:
            eng = create_async_engine(postgres_url)
            sf = async_sessionmaker(eng, expire_on_commit=False)
            try:
                async with sf() as session:
                    job = await _repo.insert_pending(
                        session, legacy_req, callback_url=None
                    )
                    await session.commit()
                    return job.job_id
            finally:
                await eng.dispose()

        asyncio.run(_insert())

        resp = app.get("/ui/jobs?client_id=ui-legacy")
        assert resp.status_code == 200
        assert "listing-legacy.pdf" in resp.text

    def test_jobs_list_header_link_from_index(
        self,
        app: TestClient,
    ) -> None:
        resp = app.get("/ui")
        assert resp.status_code == 200
        assert 'href="/ui/jobs"' in resp.text

    def test_jobs_list_header_link_from_detail(
        self,
        app: TestClient,
        postgres_url: str,
    ) -> None:
        rid = f"lp-hd-{uuid4().hex[:6]}"
        self._submit(app, "ui-hd", rid)
        row = _find_job(postgres_url, "ui-hd", rid)
        assert row is not None
        resp = app.get(f"/ui/jobs/{row.job_id}")
        assert resp.status_code == 200
        assert 'href="/ui/jobs"' in resp.text


def _force_error(
    postgres_url: str,
    job_id,  # type: ignore[no-untyped-def]
) -> None:
    """Flip a pending/running job to ``error`` with a canned error body."""

    import asyncio
    from datetime import UTC, datetime

    from sqlalchemy import text

    async def _go():  # type: ignore[no-untyped-def]
        eng = create_async_engine(postgres_url)
        try:
            async with eng.begin() as conn:
                await conn.execute(
                    text(
                        "UPDATE ix_jobs SET status='error', "
                        "response=CAST(:resp AS JSONB), finished_at=:now "
                        "WHERE job_id=:jid"
                    ),
                    {
                        "resp": json.dumps({"error": "IX_002_000: forced"}),
                        "now": datetime.now(UTC),
                        "jid": str(job_id),
                    },
                )
        finally:
            await eng.dispose()

    asyncio.run(_go())


def _find_job(postgres_url: str, client_id: str, request_id: str):  # type: ignore[no-untyped-def]
    """Look up an ``ix_jobs`` row via the async engine, wrapping the coroutine
    for test convenience."""

    import asyncio
    import json as _json

    async def _go():  # type: ignore[no-untyped-def]
        eng = create_async_engine(postgres_url)
        sf = async_sessionmaker(eng, expire_on_commit=False)
        try:
            async with sf() as session:
                r = await session.scalar(
                    select(IxJob).where(
                        IxJob.client_id == client_id,
                        IxJob.request_id == request_id,
                    )
                )
                if r is None:
                    return None

                class _JobRow:
                    pass

                out = _JobRow()
                out.job_id = r.job_id
                out.client_id = r.client_id
                out.request_id = r.request_id
                out.status = r.status
                if isinstance(r.request, str):
                    out.request = _json.loads(r.request)
                else:
                    out.request = r.request
                return out
        finally:
            await eng.dispose()

    return asyncio.run(_go())


def _force_done(
    postgres_url: str,
    job_id,  # type: ignore[no-untyped-def]
    response_body: dict,
) -> None:
    """Flip a pending job to ``done`` with the given response payload."""

    import asyncio
    from datetime import UTC, datetime

    from sqlalchemy import text

    async def _go():  # type: ignore[no-untyped-def]
        eng = create_async_engine(postgres_url)
        try:
            async with eng.begin() as conn:
                await conn.execute(
                    text(
                        "UPDATE ix_jobs SET status='done', "
                        "response=CAST(:resp AS JSONB), finished_at=:now "
                        "WHERE job_id=:jid"
                    ),
                    {
                        "resp": json.dumps(response_body),
                        "now": datetime.now(UTC),
                        "jid": str(job_id),
                    },
                )
        finally:
            await eng.dispose()

    asyncio.run(_go())


def _force_running(
    postgres_url: str,
    job_id,  # type: ignore[no-untyped-def]
    seconds_ago: int = 10,
) -> None:
    """Flip a pending job to ``running`` with a backdated ``started_at``.

    The fragment renders "Running for MM:SS" which needs a ``started_at`` in
    the past; 10s is enough to produce a deterministic non-zero MM:SS.
    """

    import asyncio
    from datetime import UTC, datetime, timedelta

    from sqlalchemy import text

    async def _go():  # type: ignore[no-untyped-def]
        eng = create_async_engine(postgres_url)
        try:
            async with eng.begin() as conn:
                await conn.execute(
                    text(
                        "UPDATE ix_jobs SET status='running', started_at=:t "
                        "WHERE job_id=:jid"
                    ),
                    {
                        "t": datetime.now(UTC) - timedelta(seconds=seconds_ago),
                        "jid": str(job_id),
                    },
                )
        finally:
            await eng.dispose()

    asyncio.run(_go())

tests/integration/test_worker_loop.py (new file, 325 lines)
@@ -0,0 +1,325 @@
"""Integration tests for the worker loop (Task 3.5).
|
||||
|
||||
We spin up a real worker with a fake pipeline factory and verify the lifecycle
|
||||
transitions against a live DB. Callback delivery is exercised via
|
||||
``pytest-httpx`` — callers' webhook endpoints are mocked, not run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from ix.contracts.request import Context, RequestIX
|
||||
from ix.contracts.response import ResponseIX
|
||||
from ix.pipeline.pipeline import Pipeline
|
||||
from ix.pipeline.step import Step
|
||||
from ix.store import jobs_repo
|
||||
from ix.worker.loop import Worker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
class _PassingStep(Step):
|
||||
"""Minimal fake step that writes a sentinel field on the response."""
|
||||
|
||||
step_name = "fake_pass"
|
||||
|
||||
async def validate(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
return True
|
||||
|
||||
async def process(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
response_ix.use_case = request_ix.use_case
|
||||
response_ix.ix_client_id = request_ix.ix_client_id
|
||||
response_ix.request_id = request_ix.request_id
|
||||
response_ix.ix_id = request_ix.ix_id
|
||||
return response_ix
|
||||
|
||||
|
||||
class _RaisingStep(Step):
|
||||
"""Fake step that raises a non-IX exception to exercise the worker's
|
||||
belt-and-braces error path."""
|
||||
|
||||
step_name = "fake_raise"
|
||||
|
||||
async def validate(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
return True
|
||||
|
||||
async def process(self, request_ix, response_ix): # type: ignore[no-untyped-def]
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def _ok_factory() -> Pipeline:
|
||||
return Pipeline(steps=[_PassingStep()])
|
||||
|
||||
|
||||
def _bad_factory() -> Pipeline:
|
||||
return Pipeline(steps=[_RaisingStep()])
|
||||
|
||||
|
||||
async def _insert_pending(session_factory, **kwargs): # type: ignore[no-untyped-def]
|
||||
request = RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id=kwargs.get("client", "test"),
|
||||
request_id=kwargs.get("rid", "r-1"),
|
||||
context=Context(texts=["hi"]),
|
||||
)
|
||||
async with session_factory() as session:
|
||||
job = await jobs_repo.insert_pending(
|
||||
session, request, callback_url=kwargs.get("cb")
|
||||
)
|
||||
await session.commit()
|
||||
return job
|
||||
|
||||
|
||||
async def test_worker_runs_one_job_to_done(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    job = await _insert_pending(session_factory)

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_ok_factory,
        poll_interval_seconds=0.1,
        max_running_seconds=3600,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        """Wait until the job lands in a terminal state, then stop the worker."""

        for _ in range(50):  # 5 seconds budget
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if current is not None and current.status in ("done", "error"):
                    stop.set()
                    return
        stop.set()  # timeout — let the worker exit so assertions run

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.status == "done"
        assert final.finished_at is not None


async def test_worker_pipeline_exception_marks_error(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """A step raising a non-IX exception → status=error, response carries the
    code. The worker catches what the pipeline doesn't."""

    job = await _insert_pending(session_factory)

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_bad_factory,
        poll_interval_seconds=0.1,
        max_running_seconds=3600,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        for _ in range(50):
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if current is not None and current.status == "error":
                    stop.set()
                    return
        stop.set()

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.status == "error"
        assert final.response is not None
        assert (final.response.error or "").startswith("IX_002_000")


async def test_worker_delivers_callback(
    httpx_mock: HTTPXMock,
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """callback_url on a done job → one POST, callback_status=delivered."""

    httpx_mock.add_response(url="http://caller/webhook", status_code=200)

    job = await _insert_pending(session_factory, cb="http://caller/webhook")

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_ok_factory,
        poll_interval_seconds=0.1,
        max_running_seconds=3600,
        callback_timeout_seconds=5,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        for _ in range(80):
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if (
                    current is not None
                    and current.status == "done"
                    and current.callback_status is not None
                ):
                    stop.set()
                    return
        stop.set()

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.callback_status == "delivered"


async def test_worker_marks_callback_failed_on_5xx(
    httpx_mock: HTTPXMock,
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    httpx_mock.add_response(url="http://caller/bad", status_code=500)

    job = await _insert_pending(session_factory, cb="http://caller/bad")

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_ok_factory,
        poll_interval_seconds=0.1,
        max_running_seconds=3600,
        callback_timeout_seconds=5,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        for _ in range(80):
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if (
                    current is not None
                    and current.status == "done"
                    and current.callback_status is not None
                ):
                    stop.set()
                    return
        stop.set()

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.status == "done"  # terminal state stays done
        assert final.callback_status == "failed"


async def test_worker_sweeps_orphans_at_startup(
    session_factory: async_sessionmaker[AsyncSession],
) -> None:
    """Stale running rows → pending before the loop starts picking work."""

    # Insert a job and backdate it to mimic a crashed worker mid-run.
    job = await _insert_pending(session_factory, rid="orphan")

    async with session_factory() as session:
        from sqlalchemy import text

        stale = datetime.now(UTC) - timedelta(hours=2)
        await session.execute(
            text(
                "UPDATE ix_jobs SET status='running', started_at=:t "
                "WHERE job_id=:jid"
            ),
            {"t": stale, "jid": job.job_id},
        )
        await session.commit()

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_ok_factory,
        poll_interval_seconds=0.1,
        # max_running_seconds=60 so our 2-hour-old row gets swept.
        max_running_seconds=60,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        for _ in range(80):
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if current is not None and current.status == "done":
                    stop.set()
                    return
        stop.set()

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.status == "done"
        # attempts starts at 0, gets +1 on sweep.
        assert final.attempts >= 1


@pytest.mark.parametrize("non_matching_url", ["http://x/y", None])
async def test_worker_no_callback_leaves_callback_status_none(
    session_factory: async_sessionmaker[AsyncSession],
    httpx_mock: HTTPXMock,
    non_matching_url: str | None,
) -> None:
    """Jobs without a callback_url should never get a callback_status set."""

    if non_matching_url is not None:
        # If we ever accidentally deliver, pytest-httpx will complain because
        # no mock matches — which is the signal we want.
        pass

    job = await _insert_pending(session_factory)  # cb=None by default

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_ok_factory,
        poll_interval_seconds=0.1,
        max_running_seconds=3600,
    )
    stop = asyncio.Event()

    async def _monitor() -> None:
        for _ in range(50):
            await asyncio.sleep(0.1)
            async with session_factory() as session:
                current = await jobs_repo.get(session, job.job_id)
                if current is not None and current.status == "done":
                    stop.set()
                    return
        stop.set()

    await asyncio.gather(worker.run(stop), _monitor())

    async with session_factory() as session:
        final = await jobs_repo.get(session, job.job_id)
        assert final is not None
        assert final.callback_status is None


def _unused() -> None:
    """Silence a ruff F401 for ResponseIX — kept for symmetry w/ other tests."""

    _ = ResponseIX

tests/live/__init__.py (new file, 0 lines)

tests/live/test_ollama_client_live.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Live tests for :class:`OllamaClient` — gated on ``IX_TEST_OLLAMA=1``.
|
||||
|
||||
Never runs in CI (Forgejo runner has no LAN access to Ollama). Run locally::
|
||||
|
||||
        IX_TEST_OLLAMA=1 uv run pytest tests/live/test_ollama_client_live.py -v

Assumes the Ollama server at ``http://192.168.68.42:11434`` already has
``qwen3:14b`` pulled.
"""

from __future__ import annotations

import os

import pytest

from ix.genai.ollama_client import OllamaClient
from ix.use_cases.bank_statement_header import BankStatementHeader

pytestmark = [
    pytest.mark.live,
    pytest.mark.skipif(
        os.environ.get("IX_TEST_OLLAMA") != "1",
        reason="live: IX_TEST_OLLAMA=1 required",
    ),
]

_OLLAMA_URL = "http://192.168.68.42:11434"
_MODEL = "qwen3:14b"


async def test_structured_output_round_trip() -> None:
    """Real Ollama returns a parsed BankStatementHeader instance."""
    client = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=300.0)
    result = await client.invoke(
        request_kwargs={
            "model": _MODEL,
            "messages": [
                {
                    "role": "system",
                    "content": (
                        "You extract bank statement header fields. "
                        "Return valid JSON matching the given schema. "
                        "Do not invent values."
                    ),
                },
                {
                    "role": "user",
                    "content": (
                        "Bank: Deutsche Kreditbank (DKB)\n"
                        "Currency: EUR\n"
                        "IBAN: DE89370400440532013000\n"
                        "Period: 2025-01-01 to 2025-01-31"
                    ),
                },
            ],
        },
        response_schema=BankStatementHeader,
    )
    assert isinstance(result.parsed, BankStatementHeader)
    assert isinstance(result.parsed.bank_name, str)
    assert result.parsed.bank_name  # non-empty
    assert isinstance(result.parsed.currency, str)
    assert result.model_name  # server echoes a model name


async def test_selfcheck_ok_against_real_server() -> None:
    """``selfcheck`` returns ``ok`` when the target model is pulled."""
    client = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=5.0)
    assert await client.selfcheck(expected_model=_MODEL) == "ok"
83  tests/live/test_surya_client_live.py  Normal file
@@ -0,0 +1,83 @@
"""Live test for :class:`SuryaOCRClient` — gated on ``IX_TEST_OLLAMA=1``.

Downloads real Surya models (hundreds of MB) on first run. Never runs in
CI. Exercised locally with::

    IX_TEST_OLLAMA=1 uv run pytest tests/live/test_surya_client_live.py -v

Note: requires the ``[ocr]`` extra — ``uv sync --extra ocr --extra dev``.
"""

from __future__ import annotations

import os
from pathlib import Path

import pytest

from ix.contracts import Page
from ix.segmentation import PageMetadata

pytestmark = [
    pytest.mark.live,
    pytest.mark.skipif(
        os.environ.get("IX_TEST_OLLAMA") != "1",
        reason="live: IX_TEST_OLLAMA=1 required",
    ),
]


async def test_extracts_dkb_and_iban_from_synthetic_giro() -> None:
    """Real Surya run against ``tests/fixtures/synthetic_giro.pdf``.

    Assert the flat text contains ``"DKB"`` and the IBAN without spaces.
    """
    from ix.ocr.surya_client import SuryaOCRClient

    fixture = Path(__file__).parent.parent / "fixtures" / "synthetic_giro.pdf"
    assert fixture.exists(), f"missing fixture: {fixture}"

    # Build Pages the way DocumentIngestor would for this PDF: count pages
    # via PyMuPDF so we pass the right number of inputs.
    import fitz

    doc = fitz.open(str(fixture))
    try:
        pages = [
            Page(
                page_no=i + 1,
                width=float(p.rect.width),
                height=float(p.rect.height),
                lines=[],
            )
            for i, p in enumerate(doc)
        ]
    finally:
        doc.close()

    client = SuryaOCRClient()
    result = await client.ocr(
        pages,
        files=[(fixture, "application/pdf")],
        page_metadata=[PageMetadata(file_index=0) for _ in pages],
    )

    flat_text = result.result.text or ""
    # Join page-level line texts if flat text missing (shape-safety).
    if not flat_text:
        flat_text = "\n".join(
            line.text or ""
            for page in result.result.pages
            for line in page.lines
        )

    assert "DKB" in flat_text
    assert "DE89370400440532013000" in flat_text.replace(" ", "")


async def test_selfcheck_ok_against_real_predictors() -> None:
    """``selfcheck()`` returns ``ok`` once Surya's predictors load."""
    from ix.ocr.surya_client import SuryaOCRClient

    client = SuryaOCRClient()
    assert await client.selfcheck() == "ok"
111  tests/unit/test_alembic_smoke.py  Normal file
@@ -0,0 +1,111 @@
"""Hermetic smoke test for the Alembic migration module.

Does NOT run ``alembic upgrade head`` — that requires a real database and is
exercised by the integration suite. This test only verifies the migration
module's structural integrity:

* the initial migration can be imported without side effects,
* its revision / down_revision pair is well-formed,
* ``upgrade()`` and ``downgrade()`` are callable,
* the SQL emitted by ``upgrade()`` mentions every column the spec requires.

We capture emitted SQL via ``alembic.op`` in offline mode so we don't need a
live connection. The point is that callers can look at this one test and know
the migration won't silently drift from spec §4 at import time.
"""

from __future__ import annotations

import importlib.util
from pathlib import Path

ALEMBIC_DIR = Path(__file__).resolve().parents[2] / "alembic"
INITIAL_PATH = ALEMBIC_DIR / "versions" / "001_initial_ix_jobs.py"


def _load_migration_module(path: Path):
    spec = importlib.util.spec_from_file_location(f"_test_migration_{path.stem}", path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def test_initial_migration_file_exists() -> None:
    assert INITIAL_PATH.exists(), f"missing migration: {INITIAL_PATH}"


def test_initial_migration_revision_ids() -> None:
    module = _load_migration_module(INITIAL_PATH)
    # Revision must be a non-empty string; down_revision must be None for the
    # initial migration (no parent).
    assert isinstance(module.revision, str) and module.revision
    assert module.down_revision is None


def test_initial_migration_has_upgrade_and_downgrade() -> None:
    module = _load_migration_module(INITIAL_PATH)
    assert callable(module.upgrade)
    assert callable(module.downgrade)


def test_initial_migration_source_mentions_required_columns() -> None:
    """Spec §4 columns must all appear in the migration source.

    We grep the source file rather than running the migration because running
    it needs Postgres. This is belt-and-braces: if someone renames a column
    they'll see this test fail and go update both sides in lockstep.
    """
    source = INITIAL_PATH.read_text(encoding="utf-8")
    for column in (
        "job_id",
        "ix_id",
        "client_id",
        "request_id",
        "status",
        "request",
        "response",
        "callback_url",
        "callback_status",
        "attempts",
        "created_at",
        "started_at",
        "finished_at",
    ):
        assert column in source, f"migration missing column {column!r}"


def test_initial_migration_source_mentions_indexes_and_constraint() -> None:
    source = INITIAL_PATH.read_text(encoding="utf-8")
    # Unique correlation index on (client_id, request_id).
    assert "ix_jobs_client_request" in source
    # Partial index on pending rows for the claim query.
    assert "ix_jobs_status_created" in source
    # CHECK constraint on status values.
    assert "pending" in source and "running" in source
    assert "done" in source and "error" in source


def test_models_module_declares_ix_job() -> None:
    """The ORM model mirrors the migration; both must stay in sync."""
    from ix.store.models import Base, IxJob

    assert IxJob.__tablename__ == "ix_jobs"
    # Registered in the shared Base.metadata so alembic autogenerate could
    # in principle see it — we don't rely on autogenerate, but having the
    # model in the shared metadata is what lets integration tests do
    # ``Base.metadata.create_all`` as a fast path when Alembic isn't desired.
    assert "ix_jobs" in Base.metadata.tables


def test_engine_module_exposes_factory() -> None:
    from ix.store.engine import get_engine, reset_engine

    # The engine factory is lazy and idempotent. We don't actually call
    # ``get_engine()`` here — that would require IX_POSTGRES_URL and a real
    # DB. Just confirm the symbols exist and ``reset_engine`` is safe to call
    # on a cold cache.
    assert callable(get_engine)
    reset_engine()  # no-op when nothing is cached
104  tests/unit/test_app_wiring.py  Normal file
@@ -0,0 +1,104 @@
"""Tests for ``ix.app`` lifespan / probe wiring (Task 4.3).

The lifespan selects fake clients when ``IX_TEST_MODE=fake`` and exposes
their probes via the route DI hook. These tests exercise the probe
adapter in isolation — no DB, no real Ollama/Surya.
"""

from __future__ import annotations

from typing import Literal

from ix.app import _make_ocr_probe, _make_ollama_probe, build_pipeline
from ix.config import AppConfig
from ix.genai.fake import FakeGenAIClient
from ix.ocr.fake import FakeOCRClient
from ix.pipeline.genai_step import GenAIStep
from ix.pipeline.ocr_step import OCRStep
from ix.pipeline.pipeline import Pipeline
from ix.pipeline.reliability_step import ReliabilityStep
from ix.pipeline.response_handler_step import ResponseHandlerStep
from ix.pipeline.setup_step import SetupStep


def _cfg(**overrides: object) -> AppConfig:
    return AppConfig(_env_file=None, **overrides)  # type: ignore[call-arg]


class _SelfcheckOllamaClient:
    async def invoke(self, *a: object, **kw: object) -> object:
        raise NotImplementedError

    async def selfcheck(
        self, expected_model: str
    ) -> Literal["ok", "degraded", "fail"]:
        self.called_with = expected_model
        return "ok"


class _SelfcheckOCRClient:
    async def ocr(self, *a: object, **kw: object) -> object:
        raise NotImplementedError

    async def selfcheck(self) -> Literal["ok", "fail"]:
        return "ok"


class _BrokenSelfcheckOllama:
    async def invoke(self, *a: object, **kw: object) -> object:
        raise NotImplementedError

    async def selfcheck(
        self, expected_model: str
    ) -> Literal["ok", "degraded", "fail"]:
        raise RuntimeError("boom")


class TestOllamaProbe:
    def test_fake_client_without_selfcheck_reports_ok(self) -> None:
        cfg = _cfg(test_mode="fake", default_model="gpt-oss:20b")
        probe = _make_ollama_probe(FakeGenAIClient(parsed=None), cfg)
        assert probe() == "ok"

    def test_real_selfcheck_returns_its_verdict(self) -> None:
        cfg = _cfg(default_model="gpt-oss:20b")
        client = _SelfcheckOllamaClient()
        probe = _make_ollama_probe(client, cfg)  # type: ignore[arg-type]
        assert probe() == "ok"
        assert client.called_with == "gpt-oss:20b"

    def test_selfcheck_exception_falls_back_to_fail(self) -> None:
        cfg = _cfg(default_model="gpt-oss:20b")
        probe = _make_ollama_probe(_BrokenSelfcheckOllama(), cfg)  # type: ignore[arg-type]
        assert probe() == "fail"


class TestOCRProbe:
    def test_fake_client_without_selfcheck_reports_ok(self) -> None:
        from ix.contracts.response import OCRDetails, OCRResult

        probe = _make_ocr_probe(FakeOCRClient(canned=OCRResult(result=OCRDetails())))
        assert probe() == "ok"

    def test_real_selfcheck_returns_its_verdict(self) -> None:
        probe = _make_ocr_probe(_SelfcheckOCRClient())  # type: ignore[arg-type]
        assert probe() == "ok"


class TestBuildPipeline:
    def test_assembles_all_five_steps_in_order(self) -> None:
        from ix.contracts.response import OCRDetails, OCRResult

        genai = FakeGenAIClient(parsed=None)
        ocr = FakeOCRClient(canned=OCRResult(result=OCRDetails()))
        cfg = _cfg(test_mode="fake")
        pipeline = build_pipeline(genai, ocr, cfg)
        assert isinstance(pipeline, Pipeline)
        steps = pipeline._steps  # type: ignore[attr-defined]
        assert [type(s) for s in steps] == [
            SetupStep,
            OCRStep,
            GenAIStep,
            ReliabilityStep,
            ResponseHandlerStep,
        ]
131  tests/unit/test_config.py  Normal file
@@ -0,0 +1,131 @@
"""Tests for :mod:`ix.config` — the pydantic-settings ``AppConfig``.

Guardrails we care about:

1. Every env var in spec §9 round-trips with the right type.
2. Defaults match the spec exactly when no env is set.
3. Unknown IX_ vars are ignored (``extra="ignore"``) so a typo doesn't crash
   the container at startup.
4. ``get_config()`` is cached — same instance per process — and
   ``get_config.cache_clear()`` rebuilds from the current environment (used by
   every test here to keep them independent of process state).
"""

from __future__ import annotations

import pytest

from ix.config import AppConfig, get_config


@pytest.fixture(autouse=True)
def _reset_config_cache() -> None:
    """Flush the LRU cache around every test.

    Without this, tests that set env vars would see stale data from earlier
    runs because ``get_config()`` caches the first materialised instance.
    """
    get_config.cache_clear()


def _clear_ix_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Scrub every IX_* var so defaults surface predictably.

    Tests that exercise env-based overrides still call ``monkeypatch.setenv``
    after this to dial in specific values; tests for defaults rely on this
    scrubbing so a developer's local ``.env`` can't contaminate the assertion.
    """
    import os

    for key in list(os.environ):
        if key.startswith("IX_"):
            monkeypatch.delenv(key, raising=False)


def test_defaults_match_spec(monkeypatch: pytest.MonkeyPatch) -> None:
    _clear_ix_env(monkeypatch)
    # Don't let pydantic-settings pick up the repo's .env.example.
    cfg = AppConfig(_env_file=None)  # type: ignore[call-arg]

    assert cfg.postgres_url == (
        "postgresql+asyncpg://infoxtractor:<password>"
        "@127.0.0.1:5431/infoxtractor"
    )
    assert cfg.ollama_url == "http://127.0.0.1:11434"
    assert cfg.default_model == "qwen3:14b"
    assert cfg.ocr_engine == "surya"
    assert cfg.tmp_dir == "/tmp/ix"
    assert cfg.pipeline_worker_concurrency == 1
    assert cfg.pipeline_request_timeout_seconds == 2700
    assert cfg.genai_call_timeout_seconds == 1500
    assert cfg.file_max_bytes == 52428800
    assert cfg.file_connect_timeout_seconds == 10
    assert cfg.file_read_timeout_seconds == 30
    assert cfg.render_max_pixels_per_page == 75000000
    assert cfg.log_level == "INFO"
    assert cfg.callback_timeout_seconds == 10


def test_env_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
    _clear_ix_env(monkeypatch)
    monkeypatch.setenv("IX_POSTGRES_URL", "postgresql+asyncpg://u:p@db:5432/x")
    monkeypatch.setenv("IX_OLLAMA_URL", "http://llm:11434")
    monkeypatch.setenv("IX_DEFAULT_MODEL", "llama3:8b")
    monkeypatch.setenv("IX_PIPELINE_WORKER_CONCURRENCY", "4")
    monkeypatch.setenv("IX_GENAI_CALL_TIMEOUT_SECONDS", "60")
    monkeypatch.setenv("IX_LOG_LEVEL", "DEBUG")
    monkeypatch.setenv("IX_CALLBACK_TIMEOUT_SECONDS", "30")

    cfg = AppConfig(_env_file=None)  # type: ignore[call-arg]

    assert cfg.postgres_url == "postgresql+asyncpg://u:p@db:5432/x"
    assert cfg.ollama_url == "http://llm:11434"
    assert cfg.default_model == "llama3:8b"
    assert cfg.pipeline_worker_concurrency == 4
    assert cfg.genai_call_timeout_seconds == 60
    assert cfg.log_level == "DEBUG"
    assert cfg.callback_timeout_seconds == 30


def test_get_config_is_cached(monkeypatch: pytest.MonkeyPatch) -> None:
    _clear_ix_env(monkeypatch)
    monkeypatch.setenv("IX_POSTGRES_URL", "postgresql+asyncpg://a:b@c:5432/d1")
    first = get_config()
    # Later mutation must NOT be seen until cache_clear — this is a feature,
    # not a bug: config is process-level state, not per-call.
    monkeypatch.setenv("IX_POSTGRES_URL", "postgresql+asyncpg://a:b@c:5432/d2")
    second = get_config()
    assert first is second
    assert second.postgres_url.endswith("/d1")

    get_config.cache_clear()
    third = get_config()
    assert third is not first
    assert third.postgres_url.endswith("/d2")


def test_extra_env_keys_are_ignored(monkeypatch: pytest.MonkeyPatch) -> None:
    """A typo'd IX_FOOBAR should not raise ValidationError at startup."""
    _clear_ix_env(monkeypatch)
    monkeypatch.setenv("IX_FOOBAR", "whatever")
    # Should not raise.
    cfg = AppConfig(_env_file=None)  # type: ignore[call-arg]
    assert cfg.ollama_url.startswith("http://")


def test_engine_uses_config_url(monkeypatch: pytest.MonkeyPatch) -> None:
    """``ix.store.engine`` reads the URL through ``AppConfig``.

    Task 3.2 refactors engine.py to go through ``get_config()`` instead of
    reading ``os.environ`` directly. We can't actually construct an async
    engine in a unit test (would need the DB), so we verify the resolution
    function exists and returns the configured URL.
    """
    _clear_ix_env(monkeypatch)
    monkeypatch.setenv("IX_POSTGRES_URL", "postgresql+asyncpg://a:b@c:5432/d")

    from ix.store.engine import _resolve_url

    assert _resolve_url() == "postgresql+asyncpg://a:b@c:5432/d"
@@ -31,6 +31,7 @@ from ix.contracts import (
    ResponseIX,
    SegmentCitation,
)
from ix.contracts.request import InlineUseCase, UseCaseFieldDef


class TestFileRef:

@@ -49,6 +50,24 @@ class TestFileRef:
        assert fr.headers == {"Authorization": "Token abc"}
        assert fr.max_bytes == 1_000_000

    def test_display_name_defaults_to_none(self) -> None:
        fr = FileRef(url="file:///tmp/ix/ui/abc.pdf")
        assert fr.display_name is None

    def test_display_name_roundtrip(self) -> None:
        fr = FileRef(
            url="file:///tmp/ix/ui/abc.pdf",
            display_name="my statement.pdf",
        )
        assert fr.display_name == "my statement.pdf"
        dumped = fr.model_dump_json()
        rt = FileRef.model_validate_json(dumped)
        assert rt.display_name == "my statement.pdf"
        # Backward-compat: a serialised FileRef without display_name still
        # validates cleanly (older stored jobs predate the field).
        legacy = FileRef.model_validate({"url": "file:///x.pdf"})
        assert legacy.display_name is None


class TestOptionDefaults:
    def test_ocr_defaults_match_spec(self) -> None:

@@ -182,6 +201,32 @@ class TestRequestIX:
        with pytest.raises(ValidationError):
            RequestIX.model_validate({"use_case": "x"})

    def test_use_case_inline_defaults_to_none(self) -> None:
        r = RequestIX(**self._minimal_payload())
        assert r.use_case_inline is None

    def test_use_case_inline_roundtrip(self) -> None:
        payload = self._minimal_payload()
        payload["use_case_inline"] = {
            "use_case_name": "adhoc",
            "system_prompt": "extract stuff",
            "fields": [
                {"name": "a", "type": "str", "required": True},
                {"name": "b", "type": "int"},
            ],
        }
        r = RequestIX.model_validate(payload)
        assert r.use_case_inline is not None
        assert isinstance(r.use_case_inline, InlineUseCase)
        assert r.use_case_inline.use_case_name == "adhoc"
        assert len(r.use_case_inline.fields) == 2
        assert isinstance(r.use_case_inline.fields[0], UseCaseFieldDef)
        # Round-trip through JSON
        dumped = r.model_dump_json()
        r2 = RequestIX.model_validate_json(dumped)
        assert r2.use_case_inline is not None
        assert r2.use_case_inline.fields[1].type == "int"


class TestOCRResult:
    def test_minimal_defaults(self) -> None:
60  tests/unit/test_factories.py  Normal file
@@ -0,0 +1,60 @@
"""Tests for the GenAI + OCR factories (Task 4.3).

The factories pick between fake and real clients based on
``IX_TEST_MODE``. CI runs with ``IX_TEST_MODE=fake``, production runs
without — so the selection knob is the one lever between hermetic CI and
real clients.
"""

from __future__ import annotations

from ix.config import AppConfig
from ix.genai import make_genai_client
from ix.genai.fake import FakeGenAIClient
from ix.genai.ollama_client import OllamaClient
from ix.ocr import make_ocr_client
from ix.ocr.fake import FakeOCRClient
from ix.ocr.surya_client import SuryaOCRClient


def _cfg(**overrides: object) -> AppConfig:
    """Build an AppConfig without loading the repo's .env.example."""
    return AppConfig(_env_file=None, **overrides)  # type: ignore[call-arg]


class TestGenAIFactory:
    def test_fake_mode_returns_fake(self) -> None:
        cfg = _cfg(test_mode="fake")
        client = make_genai_client(cfg)
        assert isinstance(client, FakeGenAIClient)

    def test_production_returns_ollama_with_configured_url(self) -> None:
        cfg = _cfg(
            test_mode=None,
            ollama_url="http://ollama.host:11434",
            genai_call_timeout_seconds=42,
        )
        client = make_genai_client(cfg)
        assert isinstance(client, OllamaClient)
        # Inspect the private attrs for binding correctness.
        assert client._base_url == "http://ollama.host:11434"
        assert client._per_call_timeout_s == 42


class TestOCRFactory:
    def test_fake_mode_returns_fake(self) -> None:
        cfg = _cfg(test_mode="fake")
        client = make_ocr_client(cfg)
        assert isinstance(client, FakeOCRClient)

    def test_production_surya_returns_surya(self) -> None:
        cfg = _cfg(test_mode=None, ocr_engine="surya")
        client = make_ocr_client(cfg)
        assert isinstance(client, SuryaOCRClient)

    def test_unknown_engine_raises(self) -> None:
        cfg = _cfg(test_mode=None, ocr_engine="tesseract")
        import pytest

        with pytest.raises(ValueError, match="ocr_engine"):
            make_ocr_client(cfg)
55  tests/unit/test_genai_fake.py  Normal file
@@ -0,0 +1,55 @@
"""Tests for GenAIClient Protocol + FakeGenAIClient (spec §6.3)."""

from __future__ import annotations

import pytest
from pydantic import BaseModel

from ix.genai import (
    FakeGenAIClient,
    GenAIClient,
    GenAIInvocationResult,
    GenAIUsage,
)


class _Schema(BaseModel):
    foo: str
    bar: int


class TestProtocolConformance:
    def test_fake_is_runtime_checkable_as_protocol(self) -> None:
        client = FakeGenAIClient(parsed=_Schema(foo="x", bar=1))
        assert isinstance(client, GenAIClient)


class TestReturnsCannedResult:
    async def test_defaults_populate_usage_and_model(self) -> None:
        parsed = _Schema(foo="bank", bar=2)
        client = FakeGenAIClient(parsed=parsed)
        result = await client.invoke(request_kwargs={"model": "x"}, response_schema=_Schema)
        assert isinstance(result, GenAIInvocationResult)
        assert result.parsed is parsed
        assert isinstance(result.usage, GenAIUsage)
        assert result.usage.prompt_tokens == 0
        assert result.usage.completion_tokens == 0
        assert result.model_name == "fake"

    async def test_explicit_usage_and_model_passed_through(self) -> None:
        parsed = _Schema(foo="k", bar=3)
        usage = GenAIUsage(prompt_tokens=10, completion_tokens=20)
        client = FakeGenAIClient(parsed=parsed, usage=usage, model_name="mock:0.1")
        result = await client.invoke(request_kwargs={}, response_schema=_Schema)
        assert result.usage is usage
        assert result.model_name == "mock:0.1"


class TestRaiseOnCallHook:
    async def test_raise_on_call_propagates(self) -> None:
        err = RuntimeError("ollama is down")
        client = FakeGenAIClient(
            parsed=_Schema(foo="x", bar=1), raise_on_call=err
        )
        with pytest.raises(RuntimeError, match="ollama is down"):
            await client.invoke(request_kwargs={}, response_schema=_Schema)
378  tests/unit/test_genai_step.py  Normal file
@@ -0,0 +1,378 @@
"""Tests for :class:`ix.pipeline.genai_step.GenAIStep` (spec §6.3, §7, §9.2)."""

from __future__ import annotations

from typing import Any

import httpx
import pytest
from pydantic import BaseModel, ValidationError

from ix.contracts import (
    Context,
    GenAIOptions,
    Line,
    OCRDetails,
    OCROptions,
    OCRResult,
    Options,
    Page,
    ProvenanceData,
    ProvenanceOptions,
    RequestIX,
    ResponseIX,
    SegmentCitation,
)
from ix.contracts.response import _InternalContext
from ix.errors import IXErrorCode, IXException
from ix.genai import FakeGenAIClient, GenAIInvocationResult, GenAIUsage
from ix.pipeline.genai_step import GenAIStep
from ix.segmentation import PageMetadata, SegmentIndex
from ix.use_cases.bank_statement_header import BankStatementHeader
from ix.use_cases.bank_statement_header import Request as BankReq


def _make_request(
    *,
    use_ocr: bool = True,
    ocr_only: bool = False,
    include_provenance: bool = True,
    model_name: str | None = None,
) -> RequestIX:
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id="test",
        request_id="r-1",
        context=Context(files=[], texts=[]),
        options=Options(
            ocr=OCROptions(use_ocr=use_ocr, ocr_only=ocr_only),
            gen_ai=GenAIOptions(gen_ai_model_name=model_name),
            provenance=ProvenanceOptions(
                include_provenance=include_provenance,
                max_sources_per_field=5,
            ),
        ),
    )


def _ocr_with_lines(lines: list[str]) -> OCRResult:
    return OCRResult(
        result=OCRDetails(
            text="\n".join(lines),
            pages=[
                Page(
                    page_no=1,
                    width=100.0,
                    height=200.0,
                    lines=[
                        Line(text=t, bounding_box=[0, i * 10, 10, i * 10, 10, i * 10 + 5, 0, i * 10 + 5])
                        for i, t in enumerate(lines)
                    ],
                )
            ],
        )
    )


def _response_with_segment_index(
    lines: list[str], texts: list[str] | None = None
) -> ResponseIX:
    ocr = _ocr_with_lines(lines)
    resp = ResponseIX(ocr_result=ocr)
    seg_idx = SegmentIndex.build(
        ocr_result=ocr,
        granularity="line",
        pages_metadata=[PageMetadata(file_index=0)],
    )
    resp.context = _InternalContext(
        use_case_request=BankReq(),
        use_case_response=BankStatementHeader,
        segment_index=seg_idx,
        texts=texts or [],
        pages=ocr.result.pages,
        page_metadata=[PageMetadata(file_index=0)],
    )
    return resp


class CapturingClient:
    """Records the request_kwargs + response_schema handed to invoke()."""

    def __init__(self, parsed: Any) -> None:
        self._parsed = parsed
        self.request_kwargs: dict[str, Any] | None = None
        self.response_schema: type[BaseModel] | None = None

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        self.request_kwargs = request_kwargs
        self.response_schema = response_schema
        return GenAIInvocationResult(
            parsed=self._parsed,
            usage=GenAIUsage(prompt_tokens=5, completion_tokens=7),
            model_name="captured-model",
        )


class TestValidate:
    async def test_ocr_only_skips(self) -> None:
        step = GenAIStep(
            genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
        )
        req = _make_request(ocr_only=True)
        resp = _response_with_segment_index(lines=["hello"])
        assert await step.validate(req, resp) is False

    async def test_empty_context_raises_IX_001_000(self) -> None:
        step = GenAIStep(
            genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
        )
        req = _make_request()
        resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text="")))
        resp.context = _InternalContext(
            use_case_request=BankReq(),
            use_case_response=BankStatementHeader,
            texts=[],
        )
        with pytest.raises(IXException) as ei:
            await step.validate(req, resp)
        assert ei.value.code is IXErrorCode.IX_001_000

    async def test_runs_when_texts_only(self) -> None:
        step = GenAIStep(
            genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
        )
        req = _make_request()
        resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text="")))
        resp.context = _InternalContext(
            use_case_request=BankReq(),
            use_case_response=BankStatementHeader,
            texts=["some paperless text"],
        )
        assert await step.validate(req, resp) is True

    async def test_runs_when_ocr_text_present(self) -> None:
        step = GenAIStep(
            genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
        )
        req = _make_request()
        resp = _response_with_segment_index(lines=["hello"])
        assert await step.validate(req, resp) is True


class TestProcessBasic:
    async def test_writes_ix_result_and_meta(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False)
        resp = _response_with_segment_index(lines=["hello"])
        resp = await step.process(req, resp)

        assert resp.ix_result.result["bank_name"] == "DKB"
        assert resp.ix_result.result["currency"] == "EUR"
        assert resp.ix_result.meta_data["model_name"] == "captured-model"
        assert resp.ix_result.meta_data["token_usage"]["prompt_tokens"] == 5
        assert resp.ix_result.meta_data["token_usage"]["completion_tokens"] == 7


class TestSystemPromptAssembly:
    async def test_citation_instruction_appended_when_provenance_on(self) -> None:
        parsed_wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[],
        )
        client = CapturingClient(parsed=parsed_wrapped)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=True)
        resp = _response_with_segment_index(lines=["hello"])
        await step.process(req, resp)

        messages = client.request_kwargs["messages"]  # type: ignore[index]
        system = messages[0]["content"]
        # Use-case system prompt is always there.
        assert "extract header metadata" in system
        # Citation instruction added.
        assert "segment_citations" in system
        assert "value_segment_ids" in system

    async def test_citation_instruction_absent_when_provenance_off(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False)
        resp = _response_with_segment_index(lines=["hello"])
        await step.process(req, resp)

        messages = client.request_kwargs["messages"]  # type: ignore[index]
        system = messages[0]["content"]
        assert "segment_citations" not in system


class TestUserTextFormat:
    async def test_tagged_prompt_when_provenance_on(self) -> None:
        parsed_wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[],
        )
        client = CapturingClient(parsed=parsed_wrapped)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=True)
        resp = _response_with_segment_index(lines=["alpha line", "beta line"])
        await step.process(req, resp)

        user_content = client.request_kwargs["messages"][1]["content"]  # type: ignore[index]
        assert "[p1_l0] alpha line" in user_content
        assert "[p1_l1] beta line" in user_content

    async def test_plain_prompt_when_provenance_off(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False)
        resp = _response_with_segment_index(lines=["alpha line", "beta line"])
        await step.process(req, resp)

        user_content = client.request_kwargs["messages"][1]["content"]  # type: ignore[index]
        assert "[p1_l0]" not in user_content
        assert "alpha line" in user_content
        assert "beta line" in user_content


class TestResponseSchemaChoice:
    async def test_plain_schema_when_provenance_off(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False)
        resp = _response_with_segment_index(lines=["hello"])
        await step.process(req, resp)
        assert client.response_schema is BankStatementHeader

    async def test_wrapped_schema_when_provenance_on(self) -> None:
        parsed_wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
||||
segment_citations=[],
|
||||
)
|
||||
client = CapturingClient(parsed=parsed_wrapped)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=True)
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
await step.process(req, resp)
|
||||
schema = client.response_schema
|
||||
assert schema is not None
|
||||
field_names = set(schema.model_fields.keys())
|
||||
assert field_names == {"result", "segment_citations"}
|
||||
|
||||
|
||||
class TestProvenanceMapping:
|
||||
async def test_provenance_populated_from_citations(self) -> None:
|
||||
parsed_wrapped: Any = _WrappedResponse(
|
||||
result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
||||
segment_citations=[
|
||||
SegmentCitation(
|
||||
field_path="result.bank_name",
|
||||
value_segment_ids=["p1_l0"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
client = CapturingClient(parsed=parsed_wrapped)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=True)
|
||||
resp = _response_with_segment_index(lines=["DKB"])
|
||||
resp = await step.process(req, resp)
|
||||
|
||||
assert isinstance(resp.provenance, ProvenanceData)
|
||||
fields = resp.provenance.fields
|
||||
assert "result.bank_name" in fields
|
||||
fp = fields["result.bank_name"]
|
||||
assert fp.value == "DKB"
|
||||
assert len(fp.sources) == 1
|
||||
assert fp.sources[0].segment_id == "p1_l0"
|
||||
# Reliability flags are NOT set here — ReliabilityStep does that.
|
||||
assert fp.provenance_verified is None
|
||||
assert fp.text_agreement is None
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
async def test_network_error_maps_to_IX_002_000(self) -> None:
|
||||
err = httpx.ConnectError("refused")
|
||||
client = FakeGenAIClient(
|
||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
||||
raise_on_call=err,
|
||||
)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=False)
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_002_000
|
||||
|
||||
async def test_timeout_maps_to_IX_002_000(self) -> None:
|
||||
err = httpx.ReadTimeout("slow")
|
||||
client = FakeGenAIClient(
|
||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
||||
raise_on_call=err,
|
||||
)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=False)
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_002_000
|
||||
|
||||
async def test_validation_error_maps_to_IX_002_001(self) -> None:
|
||||
class _M(BaseModel):
|
||||
x: int
|
||||
|
||||
try:
|
||||
_M(x="not-an-int") # type: ignore[arg-type]
|
||||
except ValidationError as err:
|
||||
raise_err = err
|
||||
|
||||
client = FakeGenAIClient(
|
||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
||||
raise_on_call=raise_err,
|
||||
)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=False)
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_002_001
|
||||
|
||||
|
||||
class TestModelSelection:
|
||||
async def test_request_model_override_wins(self) -> None:
|
||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
||||
client = CapturingClient(parsed=parsed)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=False, model_name="explicit-model")
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
await step.process(req, resp)
|
||||
assert client.request_kwargs["model"] == "explicit-model" # type: ignore[index]
|
||||
|
||||
async def test_falls_back_to_use_case_default(self) -> None:
|
||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
||||
client = CapturingClient(parsed=parsed)
|
||||
step = GenAIStep(genai_client=client)
|
||||
req = _make_request(include_provenance=False)
|
||||
resp = _response_with_segment_index(lines=["hello"])
|
||||
await step.process(req, resp)
|
||||
# use-case default is qwen3:14b
|
||||
assert client.request_kwargs["model"] == "qwen3:14b" # type: ignore[index]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
|
||||
class _WrappedResponse(BaseModel):
|
||||
"""Stand-in for the runtime-created ProvenanceWrappedResponse."""
|
||||
|
||||
result: BankStatementHeader
|
||||
segment_citations: list[SegmentCitation] = []
|
||||
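The provenance-mapping tests above expect cited segment ids to be resolved into per-field source records with reliability flags left unset. A minimal, dependency-free sketch of that grouping step (plain dicts stand in for the project's SegmentCitation and FieldProvenance models; the function name is hypothetical):

```python
def map_citations_to_provenance(
    citations: list[dict],
    segment_texts: dict[str, str],
) -> dict[str, dict]:
    """Group citations by field_path; resolve each cited segment's text.

    `citations` entries mirror SegmentCitation: {"field_path": ...,
    "value_segment_ids": [...]}. Reliability flags stay None here; a
    later step (ReliabilityStep in the tests above) fills them in.
    """
    fields: dict[str, dict] = {}
    for cit in citations:
        sources = [
            {"segment_id": sid, "text": segment_texts.get(sid)}
            for sid in cit["value_segment_ids"]
        ]
        fields[cit["field_path"]] = {
            "sources": sources,
            "provenance_verified": None,  # set by a later pipeline step
        }
    return fields
```

This mirrors only the grouping logic the tests assert on, not the real pipeline step.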
tests/unit/test_ingestion_fetch.py (Normal file, 138 lines added)
@@ -0,0 +1,138 @@
"""Tests for :func:`ix.ingestion.fetch.fetch_file` (spec §6.1)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from ix.contracts import FileRef
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ingestion import FetchConfig, fetch_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cfg() -> FetchConfig:
|
||||
return FetchConfig(
|
||||
connect_timeout_s=1.0,
|
||||
read_timeout_s=2.0,
|
||||
max_bytes=1024 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class TestSuccessPath:
|
||||
async def test_downloads_with_auth_header_and_writes_to_tmp(
|
||||
self, tmp_path: Path, cfg: FetchConfig, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://paperless.local/doc/123/download"
|
||||
httpx_mock.add_response(
|
||||
url=url,
|
||||
method="GET",
|
||||
status_code=200,
|
||||
content=b"%PDF-1.4 body",
|
||||
headers={"content-type": "application/pdf"},
|
||||
)
|
||||
file_ref = FileRef(
|
||||
url=url,
|
||||
headers={"Authorization": "Token abc"},
|
||||
)
|
||||
path = await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert path.exists()
|
||||
assert path.read_bytes() == b"%PDF-1.4 body"
|
||||
|
||||
# Confirm header went out.
|
||||
reqs = httpx_mock.get_requests()
|
||||
assert len(reqs) == 1
|
||||
assert reqs[0].headers["Authorization"] == "Token abc"
|
||||
|
||||
|
||||
class TestNon2xx:
|
||||
async def test_404_raises_IX_000_007(
|
||||
self, tmp_path: Path, cfg: FetchConfig, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://host.local/missing.pdf"
|
||||
httpx_mock.add_response(url=url, status_code=404, content=b"")
|
||||
file_ref = FileRef(url=url)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
assert "404" in (ei.value.detail or "")
|
||||
|
||||
async def test_500_raises_IX_000_007(
|
||||
self, tmp_path: Path, cfg: FetchConfig, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://host.local/boom.pdf"
|
||||
httpx_mock.add_response(url=url, status_code=500, content=b"oops")
|
||||
file_ref = FileRef(url=url)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
|
||||
|
||||
class TestTimeout:
|
||||
async def test_timeout_raises_IX_000_007(
|
||||
self, tmp_path: Path, cfg: FetchConfig, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://host.local/slow.pdf"
|
||||
httpx_mock.add_exception(httpx.ReadTimeout("slow"), url=url)
|
||||
file_ref = FileRef(url=url)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
|
||||
|
||||
class TestOversize:
|
||||
async def test_oversize_raises_IX_000_007(
|
||||
self, tmp_path: Path, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://host.local/big.pdf"
|
||||
cfg = FetchConfig(
|
||||
connect_timeout_s=1.0,
|
||||
read_timeout_s=2.0,
|
||||
max_bytes=100,
|
||||
)
|
||||
# 500 bytes of payload; cap is 100.
|
||||
httpx_mock.add_response(url=url, status_code=200, content=b"x" * 500)
|
||||
file_ref = FileRef(url=url)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
|
||||
async def test_per_file_max_bytes_override(
|
||||
self, tmp_path: Path, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
url = "https://host.local/mid.pdf"
|
||||
cfg = FetchConfig(
|
||||
connect_timeout_s=1.0,
|
||||
read_timeout_s=2.0,
|
||||
max_bytes=1_000_000,
|
||||
)
|
||||
# file_ref sets a tighter cap.
|
||||
httpx_mock.add_response(url=url, status_code=200, content=b"x" * 500)
|
||||
file_ref = FileRef(url=url, max_bytes=100)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
|
||||
|
||||
class TestFileUrl:
|
||||
async def test_file_scheme_reads_local(
|
||||
self, tmp_path: Path, cfg: FetchConfig
|
||||
) -> None:
|
||||
src = tmp_path / "in.pdf"
|
||||
src.write_bytes(b"%PDF-1.4\nfile scheme content")
|
||||
file_ref = FileRef(url=src.as_uri())
|
||||
dst = await fetch_file(file_ref, tmp_dir=tmp_path / "out", cfg=cfg)
|
||||
assert dst.exists()
|
||||
assert dst.read_bytes() == b"%PDF-1.4\nfile scheme content"
|
||||
|
||||
async def test_file_scheme_missing_raises(
|
||||
self, tmp_path: Path, cfg: FetchConfig
|
||||
) -> None:
|
||||
missing = tmp_path / "nope.pdf"
|
||||
file_ref = FileRef(url=missing.as_uri())
|
||||
with pytest.raises(IXException) as ei:
|
||||
await fetch_file(file_ref, tmp_dir=tmp_path, cfg=cfg)
|
||||
assert ei.value.code is IXErrorCode.IX_000_007
|
||||
tests/unit/test_ingestion_mime.py (Normal file, 96 lines added)
@@ -0,0 +1,96 @@
"""Tests for MIME sniffing (spec §6.1)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ingestion import SUPPORTED_MIMES, detect_mime, require_supported
|
||||
|
||||
# Real-header fixtures. python-magic looks at bytes, not extensions, so
|
||||
# these are the smallest valid-byte samples we can produce on the fly.
|
||||
|
||||
_PDF_BYTES = b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<>>\nendobj\nxref\n0 1\n0000000000 65535 f\ntrailer<<>>\nstartxref\n0\n%%EOF\n"
|
||||
|
||||
_PNG_BYTES = bytes.fromhex(
|
||||
# PNG magic + minimal IHDR + IDAT + IEND (1x1 all-black).
|
||||
"89504e470d0a1a0a"
|
||||
"0000000d49484452"
|
||||
"00000001000000010806000000"
|
||||
"1f15c4890000000d"
|
||||
"49444154789c6300010000000500010d0a2db400000000"
|
||||
"49454e44ae426082"
|
||||
)
|
||||
|
||||
_JPEG_BYTES = (
|
||||
b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"
|
||||
b"\xff\xdb\x00C\x00" + b"\x08" * 64
|
||||
+ b"\xff\xc0\x00\x0b\x08\x00\x01\x00\x01\x01\x01\x11\x00"
|
||||
+ b"\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b"
|
||||
+ b"\xff\xda\x00\x08\x01\x01\x00\x00?\x00\xfb\xff\xd9"
|
||||
)
|
||||
|
||||
|
||||
def _make_tiff_bytes() -> bytes:
|
||||
# Tiny valid TIFF via PIL.
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
buf = BytesIO()
|
||||
Image.new("L", (2, 2), color=0).save(buf, format="TIFF")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
_TIFF_BYTES = _make_tiff_bytes()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixtures_dir(tmp_path: Path) -> Path:
|
||||
d = tmp_path / "fixtures"
|
||||
d.mkdir()
|
||||
(d / "sample.pdf").write_bytes(_PDF_BYTES)
|
||||
(d / "sample.png").write_bytes(_PNG_BYTES)
|
||||
(d / "sample.jpg").write_bytes(_JPEG_BYTES)
|
||||
(d / "sample.tif").write_bytes(_TIFF_BYTES)
|
||||
(d / "sample.txt").write_bytes(b"this is plain text, no magic bytes\n")
|
||||
return d
|
||||
|
||||
|
||||
class TestDetectMime:
|
||||
def test_pdf(self, fixtures_dir: Path) -> None:
|
||||
assert detect_mime(fixtures_dir / "sample.pdf") == "application/pdf"
|
||||
|
||||
def test_png(self, fixtures_dir: Path) -> None:
|
||||
assert detect_mime(fixtures_dir / "sample.png") == "image/png"
|
||||
|
||||
def test_jpeg(self, fixtures_dir: Path) -> None:
|
||||
assert detect_mime(fixtures_dir / "sample.jpg") == "image/jpeg"
|
||||
|
||||
def test_tiff(self, fixtures_dir: Path) -> None:
|
||||
assert detect_mime(fixtures_dir / "sample.tif") == "image/tiff"
|
||||
|
||||
|
||||
class TestSupportedSet:
|
||||
def test_supported_mimes_contents(self) -> None:
|
||||
assert {
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/tiff",
|
||||
} == set(SUPPORTED_MIMES)
|
||||
|
||||
|
||||
class TestRequireSupported:
|
||||
def test_allows_supported(self) -> None:
|
||||
for m in SUPPORTED_MIMES:
|
||||
require_supported(m) # no raise
|
||||
|
||||
def test_rejects_unsupported(self) -> None:
|
||||
with pytest.raises(IXException) as ei:
|
||||
require_supported("text/plain")
|
||||
assert ei.value.code is IXErrorCode.IX_000_005
|
||||
assert "text/plain" in (ei.value.detail or "")
|
||||
tests/unit/test_ingestion_pages.py (Normal file, 116 lines added)
@@ -0,0 +1,116 @@
"""Tests for DocumentIngestor.build_pages (spec §6.1)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ingestion import DocumentIngestor
|
||||
|
||||
|
||||
def _make_pdf_bytes(n_pages: int) -> bytes:
|
||||
import fitz
|
||||
|
||||
doc = fitz.open()
|
||||
for i in range(n_pages):
|
||||
page = doc.new_page(width=200, height=300)
|
||||
page.insert_text((10, 20), f"page {i+1}")
|
||||
out = doc.tobytes()
|
||||
doc.close()
|
||||
return out
|
||||
|
||||
|
||||
def _make_multi_frame_tiff_bytes(n_frames: int) -> bytes:
|
||||
from PIL import Image
|
||||
|
||||
frames = [Image.new("L", (10, 10), color=i * 30) for i in range(n_frames)]
|
||||
buf = BytesIO()
|
||||
frames[0].save(buf, format="TIFF", save_all=True, append_images=frames[1:])
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class TestPdf:
|
||||
def test_three_page_pdf_yields_three_pages(self, tmp_path: Path) -> None:
|
||||
p = tmp_path / "doc.pdf"
|
||||
p.write_bytes(_make_pdf_bytes(3))
|
||||
|
||||
ing = DocumentIngestor()
|
||||
pages, metas = ing.build_pages(files=[(p, "application/pdf")], texts=[])
|
||||
|
||||
assert len(pages) == 3
|
||||
for i, page in enumerate(pages, start=1):
|
||||
assert page.page_no == i
|
||||
assert page.width > 0
|
||||
assert page.height > 0
|
||||
assert len(metas) == 3
|
||||
for m in metas:
|
||||
assert m.file_index == 0
|
||||
|
||||
def test_page_count_cap_raises_IX_000_006(self, tmp_path: Path) -> None:
|
||||
p = tmp_path / "toomany.pdf"
|
||||
p.write_bytes(_make_pdf_bytes(101))
|
||||
ing = DocumentIngestor()
|
||||
with pytest.raises(IXException) as ei:
|
||||
ing.build_pages(files=[(p, "application/pdf")], texts=[])
|
||||
assert ei.value.code is IXErrorCode.IX_000_006
|
||||
|
||||
|
||||
class TestImages:
|
||||
def test_single_frame_png(self, tmp_path: Path) -> None:
|
||||
from PIL import Image
|
||||
|
||||
p = tmp_path / "img.png"
|
||||
Image.new("RGB", (50, 80), color="white").save(p, format="PNG")
|
||||
ing = DocumentIngestor()
|
||||
pages, metas = ing.build_pages(files=[(p, "image/png")], texts=[])
|
||||
assert len(pages) == 1
|
||||
assert pages[0].width == 50
|
||||
assert pages[0].height == 80
|
||||
assert metas[0].file_index == 0
|
||||
|
||||
def test_multi_frame_tiff_yields_multiple_pages(self, tmp_path: Path) -> None:
|
||||
p = tmp_path / "multi.tif"
|
||||
p.write_bytes(_make_multi_frame_tiff_bytes(2))
|
||||
|
||||
ing = DocumentIngestor()
|
||||
pages, metas = ing.build_pages(files=[(p, "image/tiff")], texts=[])
|
||||
assert len(pages) == 2
|
||||
for page in pages:
|
||||
assert page.width == 10
|
||||
assert page.height == 10
|
||||
# Both frames share the same file_index.
|
||||
assert {m.file_index for m in metas} == {0}
|
||||
|
||||
|
||||
class TestTexts:
|
||||
def test_texts_become_pages(self) -> None:
|
||||
ing = DocumentIngestor()
|
||||
pages, metas = ing.build_pages(files=[], texts=["hello", "world"])
|
||||
assert len(pages) == 2
|
||||
assert pages[0].page_no == 1
|
||||
assert pages[1].page_no == 2
|
||||
# Text-backed pages have no file_index source.
|
||||
assert metas[0].file_index is None
|
||||
assert metas[1].file_index is None
|
||||
|
||||
|
||||
class TestFileIndexes:
|
||||
def test_multi_file_indexes_are_contiguous(self, tmp_path: Path) -> None:
|
||||
p1 = tmp_path / "a.pdf"
|
||||
p1.write_bytes(_make_pdf_bytes(2))
|
||||
p2 = tmp_path / "b.pdf"
|
||||
p2.write_bytes(_make_pdf_bytes(1))
|
||||
|
||||
ing = DocumentIngestor()
|
||||
pages, metas = ing.build_pages(
|
||||
files=[(p1, "application/pdf"), (p2, "application/pdf")],
|
||||
texts=[],
|
||||
)
|
||||
assert len(pages) == 3
|
||||
# First two pages from file 0, last from file 1.
|
||||
assert metas[0].file_index == 0
|
||||
assert metas[1].file_index == 0
|
||||
assert metas[2].file_index == 1
|
||||
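The page/metadata invariants these tests pin down (global 1-based page numbering, contiguous file indexes, `file_index=None` for text-backed pages) can be sketched in a few lines. This is an illustrative stand-in for the bookkeeping inside `build_pages`, not the actual implementation:

```python
def assign_page_metadata(page_counts: list[int], n_texts: int) -> list[dict]:
    """Assign (page_no, file_index) across files and trailing text inputs.

    `page_counts[k]` is how many pages file k contributed. Pages keep a
    single global 1-based page_no; text pages have no source file.
    """
    metas: list[dict] = []
    page_no = 1
    for file_index, count in enumerate(page_counts):
        for _ in range(count):
            metas.append({"page_no": page_no, "file_index": file_index})
            page_no += 1
    for _ in range(n_texts):
        metas.append({"page_no": page_no, "file_index": None})
        page_no += 1
    return metas
```

With `page_counts=[2, 1]` this reproduces the `[0, 0, 1]` file-index pattern asserted in `TestFileIndexes`.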
tests/unit/test_ocr_fake.py (Normal file, 57 lines added)
@@ -0,0 +1,57 @@
"""Tests for OCRClient Protocol + FakeOCRClient (spec §6.2)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.contracts import Line, OCRDetails, OCRResult, Page
|
||||
from ix.ocr import FakeOCRClient, OCRClient
|
||||
|
||||
|
||||
def _canned() -> OCRResult:
|
||||
return OCRResult(
|
||||
result=OCRDetails(
|
||||
text="hello world",
|
||||
pages=[
|
||||
Page(
|
||||
page_no=1,
|
||||
width=100.0,
|
||||
height=200.0,
|
||||
lines=[Line(text="hello world", bounding_box=[0, 0, 10, 0, 10, 5, 0, 5])],
|
||||
)
|
||||
],
|
||||
),
|
||||
meta_data={"engine": "fake"},
|
||||
)
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
def test_fake_is_runtime_checkable_as_protocol(self) -> None:
|
||||
client = FakeOCRClient(canned=_canned())
|
||||
assert isinstance(client, OCRClient)
|
||||
|
||||
|
||||
class TestReturnsCannedResult:
|
||||
async def test_returns_exact_canned_result(self) -> None:
|
||||
canned = _canned()
|
||||
client = FakeOCRClient(canned=canned)
|
||||
result = await client.ocr(pages=[])
|
||||
assert result is canned
|
||||
assert result.result.text == "hello world"
|
||||
assert result.meta_data == {"engine": "fake"}
|
||||
|
||||
async def test_pages_argument_is_accepted_but_ignored(self) -> None:
|
||||
canned = _canned()
|
||||
client = FakeOCRClient(canned=canned)
|
||||
result = await client.ocr(
|
||||
pages=[Page(page_no=5, width=1.0, height=1.0, lines=[])]
|
||||
)
|
||||
assert result is canned
|
||||
|
||||
|
||||
class TestRaiseOnCallHook:
|
||||
async def test_raise_on_call_propagates(self) -> None:
|
||||
err = RuntimeError("surya is down")
|
||||
client = FakeOCRClient(canned=_canned(), raise_on_call=err)
|
||||
with pytest.raises(RuntimeError, match="surya is down"):
|
||||
await client.ocr(pages=[])
|
||||
tests/unit/test_ocr_step.py (Normal file, 199 lines added)
@@ -0,0 +1,199 @@
"""Tests for :class:`ix.pipeline.ocr_step.OCRStep` (spec §6.2)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.contracts import (
|
||||
Context,
|
||||
Line,
|
||||
OCRDetails,
|
||||
OCROptions,
|
||||
OCRResult,
|
||||
Options,
|
||||
Page,
|
||||
ProvenanceOptions,
|
||||
RequestIX,
|
||||
ResponseIX,
|
||||
)
|
||||
from ix.contracts.response import _InternalContext
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ocr import FakeOCRClient
|
||||
from ix.pipeline.ocr_step import OCRStep
|
||||
from ix.segmentation import PageMetadata, SegmentIndex
|
||||
|
||||
|
||||
def _make_request(
|
||||
*,
|
||||
use_ocr: bool = True,
|
||||
include_geometries: bool = False,
|
||||
include_ocr_text: bool = False,
|
||||
ocr_only: bool = False,
|
||||
include_provenance: bool = True,
|
||||
files: list | None = None,
|
||||
texts: list[str] | None = None,
|
||||
) -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="test",
|
||||
request_id="r-1",
|
||||
context=Context(files=files if files is not None else [], texts=texts or []),
|
||||
options=Options(
|
||||
ocr=OCROptions(
|
||||
use_ocr=use_ocr,
|
||||
include_geometries=include_geometries,
|
||||
include_ocr_text=include_ocr_text,
|
||||
ocr_only=ocr_only,
|
||||
),
|
||||
provenance=ProvenanceOptions(include_provenance=include_provenance),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _response_with_context(
|
||||
*,
|
||||
pages: list[Page] | None = None,
|
||||
files: list | None = None,
|
||||
texts: list[str] | None = None,
|
||||
page_metadata: list[PageMetadata] | None = None,
|
||||
) -> ResponseIX:
|
||||
resp = ResponseIX()
|
||||
resp.context = _InternalContext(
|
||||
pages=pages or [],
|
||||
files=files or [],
|
||||
texts=texts or [],
|
||||
page_metadata=page_metadata or [],
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
def _canned_ocr(pages: int = 1) -> OCRResult:
|
||||
return OCRResult(
|
||||
result=OCRDetails(
|
||||
text="\n".join(f"text p{i+1}" for i in range(pages)),
|
||||
pages=[
|
||||
Page(
|
||||
page_no=i + 1,
|
||||
width=100.0,
|
||||
height=200.0,
|
||||
lines=[
|
||||
Line(
|
||||
text=f"line-content p{i+1}",
|
||||
bounding_box=[0, 0, 10, 0, 10, 5, 0, 5],
|
||||
)
|
||||
],
|
||||
)
|
||||
for i in range(pages)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestValidate:
|
||||
async def test_ocr_only_without_files_raises_IX_000_004(self) -> None:
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr()))
|
||||
req = _make_request(ocr_only=True, files=[], texts=["hi"])
|
||||
resp = _response_with_context(files=[])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.validate(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_000_004
|
||||
|
||||
async def test_include_ocr_text_without_files_raises_IX_000_004(self) -> None:
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr()))
|
||||
req = _make_request(include_ocr_text=True, files=[], texts=["hi"])
|
||||
resp = _response_with_context(files=[])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.validate(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_000_004
|
||||
|
||||
async def test_include_geometries_without_files_raises_IX_000_004(self) -> None:
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr()))
|
||||
req = _make_request(include_geometries=True, files=[], texts=["hi"])
|
||||
resp = _response_with_context(files=[])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.validate(req, resp)
|
||||
assert ei.value.code is IXErrorCode.IX_000_004
|
||||
|
||||
async def test_text_only_skips_step(self) -> None:
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr()))
|
||||
req = _make_request(use_ocr=True, files=[], texts=["hi"])
|
||||
resp = _response_with_context(files=[], texts=["hi"])
|
||||
assert await step.validate(req, resp) is False
|
||||
|
||||
async def test_ocr_runs_when_files_and_use_ocr(self) -> None:
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr()))
|
||||
req = _make_request(use_ocr=True, files=["http://x"])
|
||||
resp = _response_with_context(files=[("/tmp/x.pdf", "application/pdf")])
|
||||
assert await step.validate(req, resp) is True
|
||||
|
||||
|
||||
class TestProcess:
|
||||
async def test_ocr_result_written_to_response(self) -> None:
|
||||
canned = _canned_ocr(pages=2)
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=canned))
|
||||
req = _make_request(use_ocr=True, files=["http://x"])
|
||||
resp = _response_with_context(
|
||||
pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])] * 2,
|
||||
files=[("/tmp/x.pdf", "application/pdf")],
|
||||
page_metadata=[PageMetadata(file_index=0), PageMetadata(file_index=0)],
|
||||
)
|
||||
resp = await step.process(req, resp)
|
||||
# Full OCR result written.
|
||||
assert resp.ocr_result.result.text == canned.result.text
|
||||
# Page tags injected: prepend + append around lines per page.
|
||||
pages = resp.ocr_result.result.pages
|
||||
assert len(pages) == 2
|
||||
# First line is the opening <page> tag.
|
||||
assert pages[0].lines[0].text is not None
|
||||
assert pages[0].lines[0].text.startswith("<page ")
|
||||
# Last line is the closing </page> tag.
|
||||
assert pages[0].lines[-1].text == "</page>"
|
||||
|
||||
async def test_segment_index_built_when_provenance_on(self) -> None:
|
||||
canned = _canned_ocr(pages=1)
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=canned))
|
||||
req = _make_request(
|
||||
use_ocr=True, include_provenance=True, files=["http://x"]
|
||||
)
|
||||
resp = _response_with_context(
|
||||
pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])],
|
||||
files=[("/tmp/x.pdf", "application/pdf")],
|
||||
page_metadata=[PageMetadata(file_index=0)],
|
||||
)
|
||||
resp = await step.process(req, resp)
|
||||
seg_idx = resp.context.segment_index # type: ignore[union-attr]
|
||||
assert isinstance(seg_idx, SegmentIndex)
|
||||
# Page-tag lines are excluded; only the real line becomes a segment.
|
||||
assert seg_idx._ordered_ids == ["p1_l0"]
|
||||
pos = seg_idx.lookup_segment("p1_l0")
|
||||
assert pos is not None
|
||||
assert pos["text"] == "line-content p1"
|
||||
|
||||
async def test_segment_index_not_built_when_provenance_off(self) -> None:
|
||||
canned = _canned_ocr(pages=1)
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=canned))
|
||||
req = _make_request(
|
||||
use_ocr=True, include_provenance=False, files=["http://x"]
|
||||
)
|
||||
resp = _response_with_context(
|
||||
pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])],
|
||||
files=[("/tmp/x.pdf", "application/pdf")],
|
||||
page_metadata=[PageMetadata(file_index=0)],
|
||||
)
|
||||
resp = await step.process(req, resp)
|
||||
assert resp.context.segment_index is None # type: ignore[union-attr]
|
||||
|
||||
async def test_page_tags_include_file_index(self) -> None:
|
||||
canned = _canned_ocr(pages=1)
|
||||
step = OCRStep(ocr_client=FakeOCRClient(canned=canned))
|
||||
req = _make_request(use_ocr=True, files=["http://x"])
|
||||
resp = _response_with_context(
|
||||
pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])],
|
||||
files=[("/tmp/x.pdf", "application/pdf")],
|
||||
page_metadata=[PageMetadata(file_index=3)],
|
||||
)
|
||||
resp = await step.process(req, resp)
|
||||
first_line = resp.ocr_result.result.pages[0].lines[0].text
|
||||
assert first_line is not None
|
||||
assert 'file="3"' in first_line
|
||||
assert 'number="1"' in first_line
|
||||
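The page-tag assertions above (opening `<page file=… number=…>` line, closing `</page>` line, real lines in between) amount to a simple wrapping transform. A stand-in sketch of that injection (the helper name is hypothetical; the real logic lives inside `OCRStep.process`):

```python
def wrap_page_lines(lines: list[str], *, file_index: int, page_no: int) -> list[str]:
    """Surround one page's OCR lines with synthetic <page> marker lines.

    The markers let a downstream prompt builder keep per-page/per-file
    boundaries visible in the flattened text, while segment indexing
    skips them so only real OCR lines get segment ids like p1_l0.
    """
    open_tag = f'<page file="{file_index}" number="{page_no}">'
    return [open_tag, *lines, "</page>"]
```

With `file_index=3, page_no=1` this reproduces the `file="3"` / `number="1"` attributes asserted in `test_page_tags_include_file_index`.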
tests/unit/test_ollama_client.py (Normal file, 270 lines added)
@@ -0,0 +1,270 @@
"""Tests for :class:`OllamaClient` — hermetic, pytest-httpx-driven.
|
||||
|
||||
Covers spec §6 GenAIStep Ollama call contract:
|
||||
|
||||
* POST body shape (model / messages / format / stream / options).
|
||||
* Response parsing → :class:`GenAIInvocationResult`.
|
||||
* Error mapping: connection / timeout / 5xx → ``IX_002_000``;
|
||||
schema-violating body → ``IX_002_001``.
|
||||
* ``selfcheck()``: tags-reachable + model-listed → ``ok``;
|
||||
reachable-but-missing → ``degraded``; unreachable → ``fail``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
from pytest_httpx import HTTPXMock

from ix.errors import IXErrorCode, IXException
from ix.genai.ollama_client import OllamaClient


class _Schema(BaseModel):
    """Trivial structured-output schema for the round-trip tests."""

    bank_name: str
    account_number: str | None = None


def _ollama_chat_ok_body(content_json: str) -> dict:
    """Build a minimal Ollama /api/chat success body."""
    return {
        "model": "gpt-oss:20b",
        "message": {"role": "assistant", "content": content_json},
        "done": True,
        "eval_count": 42,
        "prompt_eval_count": 17,
    }


class TestInvokeHappyPath:
    async def test_posts_to_chat_endpoint_with_format_and_no_stream(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"DKB","account_number":"DE89"}'),
        )

        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        result = await client.invoke(
            request_kwargs={
                "model": "gpt-oss:20b",
                "messages": [
                    {"role": "system", "content": "You extract."},
                    {"role": "user", "content": "Doc body"},
                ],
                "temperature": 0.2,
                "reasoning_effort": "high",  # dropped silently
            },
            response_schema=_Schema,
        )

        assert result.parsed == _Schema(bank_name="DKB", account_number="DE89")
        assert result.model_name == "gpt-oss:20b"
        assert result.usage.prompt_tokens == 17
        assert result.usage.completion_tokens == 42

        # Verify request shape.
        requests = httpx_mock.get_requests()
        assert len(requests) == 1
        body = requests[0].read().decode()
        import json

        body_json = json.loads(body)
        assert body_json["model"] == "gpt-oss:20b"
        assert body_json["stream"] is False
        # No `format` is sent: Ollama 0.11.8 segfaults on full schemas and
        # aborts to `{}` with `format=json` on reasoning models. Schema is
        # injected into the system prompt instead; we extract the trailing
        # JSON blob from the response and validate via Pydantic.
        assert "format" not in body_json
        assert body_json["options"]["temperature"] == 0.2
        assert "reasoning_effort" not in body_json
        # A schema-guidance system message is prepended to the caller's
        # messages so Ollama (format=json loose mode) emits the right shape.
        msgs = body_json["messages"]
        assert msgs[0]["role"] == "system"
        assert "JSON Schema" in msgs[0]["content"]
        assert msgs[1:] == [
            {"role": "system", "content": "You extract."},
            {"role": "user", "content": "Doc body"},
        ]

    async def test_text_parts_content_list_is_joined(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"X"}'),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        await client.invoke(
            request_kwargs={
                "model": "gpt-oss:20b",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "part-a"},
                            {"type": "text", "text": "part-b"},
                        ],
                    }
                ],
            },
            response_schema=_Schema,
        )
        import json

        request_body = json.loads(httpx_mock.get_requests()[0].read())
        # First message is the auto-injected schema guidance; after that
        # the caller's user message has its text parts joined.
        assert request_body["messages"][0]["role"] == "system"
        assert request_body["messages"][1:] == [
            {"role": "user", "content": "part-a\npart-b"}
        ]


class TestInvokeErrorPaths:
    async def test_connection_error_maps_to_002_000(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=1.0
        )
        with pytest.raises(IXException) as ei:
            await client.invoke(
                request_kwargs={
                    "model": "gpt-oss:20b",
                    "messages": [{"role": "user", "content": "hi"}],
                },
                response_schema=_Schema,
            )
        assert ei.value.code is IXErrorCode.IX_002_000

    async def test_read_timeout_maps_to_002_000(self, httpx_mock: HTTPXMock) -> None:
        httpx_mock.add_exception(httpx.ReadTimeout("slow"))
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=0.5
        )
        with pytest.raises(IXException) as ei:
            await client.invoke(
                request_kwargs={
                    "model": "gpt-oss:20b",
                    "messages": [{"role": "user", "content": "hi"}],
                },
                response_schema=_Schema,
            )
        assert ei.value.code is IXErrorCode.IX_002_000

    async def test_500_maps_to_002_000_with_body_snippet(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            status_code=500,
            text="boom boom server broken",
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await client.invoke(
                request_kwargs={
                    "model": "gpt-oss:20b",
                    "messages": [{"role": "user", "content": "hi"}],
                },
                response_schema=_Schema,
            )
        assert ei.value.code is IXErrorCode.IX_002_000
        assert "boom" in (ei.value.detail or "")

    async def test_200_with_invalid_json_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body("not-json"),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await client.invoke(
                request_kwargs={
                    "model": "gpt-oss:20b",
                    "messages": [{"role": "user", "content": "hi"}],
                },
                response_schema=_Schema,
            )
        assert ei.value.code is IXErrorCode.IX_002_001

    async def test_200_with_schema_violation_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        # Missing required `bank_name` field.
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"account_number":"DE89"}'),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await client.invoke(
                request_kwargs={
                    "model": "gpt-oss:20b",
                    "messages": [{"role": "user", "content": "hi"}],
                },
                response_schema=_Schema,
            )
        assert ei.value.code is IXErrorCode.IX_002_001


class TestSelfcheck:
    async def test_selfcheck_ok_when_model_listed(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/tags",
            method="GET",
            json={"models": [{"name": "gpt-oss:20b"}, {"name": "qwen2.5:32b"}]},
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await client.selfcheck(expected_model="gpt-oss:20b") == "ok"

    async def test_selfcheck_degraded_when_model_missing(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/tags",
            method="GET",
            json={"models": [{"name": "qwen2.5:32b"}]},
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await client.selfcheck(expected_model="gpt-oss:20b") == "degraded"

    async def test_selfcheck_fail_on_connection_error(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await client.selfcheck(expected_model="gpt-oss:20b") == "fail"
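The happy-path test's comments describe the client's workaround: with no `format` field sent, the schema goes into the system prompt and the trailing JSON object is pulled out of the model's reply before Pydantic validation. A minimal sketch of that extraction, assuming only that the last balanced `{...}` in the text is the payload (function name hypothetical; braces inside JSON string values are not handled):

```python
import json


def extract_trailing_json(text: str) -> dict:
    """Return the last balanced {...} object in `text`, parsed as JSON."""
    end = text.rfind("}")
    if end == -1:
        raise ValueError("no JSON object found in model output")
    depth = 0
    # Scan backwards from the final brace until the object balances.
    for start in range(end, -1, -1):
        if text[start] == "}":
            depth += 1
        elif text[start] == "{":
            depth -= 1
            if depth == 0:
                return json.loads(text[start : end + 1])
    raise ValueError("unbalanced JSON object in model output")
```

The parsed dict would then be fed to `_Schema.model_validate(...)`, which is what turns a shape mismatch into the IX_002_001 path exercised above.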
208 tests/unit/test_pipeline.py Normal file
@@ -0,0 +1,208 @@
"""Tests for the Pipeline orchestrator + Step ABC + Timer (spec §3, §4)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.contracts import Context, RequestIX, ResponseIX
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.pipeline import Pipeline, Step
|
||||
|
||||
|
||||
def _make_request() -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="test",
|
||||
request_id="r-1",
|
||||
context=Context(files=[], texts=["hello"]),
|
||||
)
|
||||
|
||||
|
||||
class StubStep(Step):
|
||||
"""Hand-written Step double. Records call order on a shared list."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
calls: list[str] | None = None,
|
||||
validate_returns: bool = True,
|
||||
validate_raises: IXException | None = None,
|
||||
process_raises: IXException | BaseException | None = None,
|
||||
mutate: Any = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self._calls = calls if calls is not None else []
|
||||
self._validate_returns = validate_returns
|
||||
self._validate_raises = validate_raises
|
||||
self._process_raises = process_raises
|
||||
self._mutate = mutate
|
||||
|
||||
@property
|
||||
def step_name(self) -> str: # type: ignore[override]
|
||||
return self.name
|
||||
|
||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
||||
self._calls.append(f"{self.name}.validate")
|
||||
if self._validate_raises is not None:
|
||||
raise self._validate_raises
|
||||
return self._validate_returns
|
||||
|
||||
async def process(
|
||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
||||
) -> ResponseIX:
|
||||
self._calls.append(f"{self.name}.process")
|
||||
if self._process_raises is not None:
|
||||
raise self._process_raises
|
||||
if self._mutate is not None:
|
||||
self._mutate(response_ix)
|
||||
return response_ix
|
||||
|
||||
|
||||
class TestOrdering:
|
||||
async def test_steps_run_in_registered_order(self) -> None:
|
||||
calls: list[str] = []
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", calls=calls),
|
||||
StubStep("b", calls=calls),
|
||||
StubStep("c", calls=calls),
|
||||
]
|
||||
)
|
||||
await pipeline.start(_make_request())
|
||||
assert calls == [
|
||||
"a.validate",
|
||||
"a.process",
|
||||
"b.validate",
|
||||
"b.process",
|
||||
"c.validate",
|
||||
"c.process",
|
||||
]
|
||||
|
||||
|
||||
class TestSkipOnFalse:
|
||||
async def test_validate_false_skips_process(self) -> None:
|
||||
calls: list[str] = []
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", calls=calls, validate_returns=False),
|
||||
StubStep("b", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
assert calls == ["a.validate", "b.validate", "b.process"]
|
||||
assert response.error is None
|
||||
|
||||
|
||||
class TestErrorFromValidate:
|
||||
async def test_validate_raising_ix_sets_error_and_aborts(self) -> None:
|
||||
calls: list[str] = []
|
||||
err = IXException(IXErrorCode.IX_000_002, detail="nothing")
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", calls=calls, validate_raises=err),
|
||||
StubStep("b", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
assert calls == ["a.validate"]
|
||||
assert response.error is not None
|
||||
assert response.error.startswith("IX_000_002")
|
||||
assert "nothing" in response.error
|
||||
|
||||
|
||||
class TestErrorFromProcess:
|
||||
async def test_process_raising_ix_sets_error_and_aborts(self) -> None:
|
||||
calls: list[str] = []
|
||||
err = IXException(IXErrorCode.IX_002_001, detail="bad parse")
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", calls=calls, process_raises=err),
|
||||
StubStep("b", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
assert calls == ["a.validate", "a.process"]
|
||||
assert response.error is not None
|
||||
assert response.error.startswith("IX_002_001")
|
||||
assert "bad parse" in response.error
|
||||
|
||||
|
||||
class TestTimings:
|
||||
async def test_timings_populated_for_every_executed_step(self) -> None:
|
||||
calls: list[str] = []
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("alpha", calls=calls),
|
||||
StubStep("beta", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
# One timing per executed step.
|
||||
names = [t["step"] for t in response.metadata.timings]
|
||||
assert names == ["alpha", "beta"]
|
||||
for entry in response.metadata.timings:
|
||||
assert isinstance(entry["elapsed_seconds"], float)
|
||||
assert entry["elapsed_seconds"] >= 0.0
|
||||
|
||||
async def test_timing_recorded_even_when_validate_false(self) -> None:
|
||||
calls: list[str] = []
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("skipper", calls=calls, validate_returns=False),
|
||||
StubStep("runner", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
names = [t["step"] for t in response.metadata.timings]
|
||||
assert names == ["skipper", "runner"]
|
||||
|
||||
async def test_timing_recorded_when_validate_raises(self) -> None:
|
||||
calls: list[str] = []
|
||||
err = IXException(IXErrorCode.IX_001_001, detail="missing")
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("boom", calls=calls, validate_raises=err),
|
||||
StubStep("unreached", calls=calls),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
names = [t["step"] for t in response.metadata.timings]
|
||||
assert names == ["boom"]
|
||||
|
||||
|
||||
class TestNonIXExceptionStillSurfaces:
|
||||
async def test_generic_exception_in_process_aborts(self) -> None:
|
||||
calls: list[str] = []
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", calls=calls, process_raises=RuntimeError("kaboom")),
|
||||
StubStep("b", calls=calls),
|
||||
]
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
await pipeline.start(_make_request())
|
||||
assert calls == ["a.validate", "a.process"]
|
||||
|
||||
|
||||
class TestStepMutation:
|
||||
async def test_response_ix_shared_across_steps(self) -> None:
|
||||
def mutate_a(resp: ResponseIX) -> None:
|
||||
resp.use_case = "set-by-a"
|
||||
|
||||
def mutate_b(resp: ResponseIX) -> None:
|
||||
# b sees what a wrote.
|
||||
assert resp.use_case == "set-by-a"
|
||||
resp.use_case_name = "set-by-b"
|
||||
|
||||
pipeline = Pipeline(
|
||||
steps=[
|
||||
StubStep("a", mutate=mutate_a),
|
||||
StubStep("b", mutate=mutate_b),
|
||||
]
|
||||
)
|
||||
response = await pipeline.start(_make_request())
|
||||
assert response.use_case == "set-by-a"
|
||||
assert response.use_case_name == "set-by-b"
|
||||
272 tests/unit/test_pipeline_end_to_end.py Normal file
@@ -0,0 +1,272 @@
"""End-to-end pipeline test with the fake OCR + GenAI clients (spec sections 6-9).
|
||||
|
||||
Feeds the committed ``tests/fixtures/synthetic_giro.pdf`` through the
|
||||
full five-step pipeline with canned OCR + canned LLM responses.
|
||||
Hermetic: no Surya, no Ollama, no network.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ix.contracts import (
|
||||
Context,
|
||||
Line,
|
||||
OCRDetails,
|
||||
OCROptions,
|
||||
OCRResult,
|
||||
Options,
|
||||
Page,
|
||||
ProvenanceOptions,
|
||||
RequestIX,
|
||||
SegmentCitation,
|
||||
)
|
||||
from ix.genai import FakeGenAIClient, GenAIUsage
|
||||
from ix.ocr import FakeOCRClient
|
||||
from ix.pipeline import Pipeline
|
||||
from ix.pipeline.genai_step import GenAIStep
|
||||
from ix.pipeline.ocr_step import OCRStep
|
||||
from ix.pipeline.reliability_step import ReliabilityStep
|
||||
from ix.pipeline.response_handler_step import ResponseHandlerStep
|
||||
from ix.pipeline.setup_step import SetupStep
|
||||
from ix.use_cases.bank_statement_header import BankStatementHeader
|
||||
|
||||
FIXTURE_PDF = Path(__file__).resolve().parent.parent / "fixtures" / "synthetic_giro.pdf"
|
||||
|
||||
|
||||
# Ground-truth values. Must match the strings the fixture builder drops on
|
||||
# the page AND the canned OCR output below.
|
||||
EXPECTED_BANK_NAME = "DKB"
|
||||
EXPECTED_IBAN = "DE89370400440532013000"
|
||||
EXPECTED_OPENING = Decimal("1234.56")
|
||||
EXPECTED_CLOSING = Decimal("1450.22")
|
||||
EXPECTED_CURRENCY = "EUR"
|
||||
EXPECTED_STATEMENT_DATE = date(2026, 3, 31)
|
||||
EXPECTED_PERIOD_START = date(2026, 3, 1)
|
||||
EXPECTED_PERIOD_END = date(2026, 3, 31)
|
||||
|
||||
|
||||
def _canned_ocr_result() -> OCRResult:
|
||||
"""Canned Surya-shaped result for the synthetic_giro fixture.
|
||||
|
||||
Line texts match the strings placed by create_fixture_pdf.py. Bboxes
|
||||
are plausible-but-not-exact: the fixture builder uses 72 pt left
|
||||
margin and 24 pt line height on a 595x842 page, so we mirror those
|
||||
coords here so normalisation gives sensible 0-1 values.
|
||||
"""
|
||||
width, height = 595.0, 842.0
|
||||
lines_meta = [
|
||||
("DKB", 60.0),
|
||||
("IBAN: DE89370400440532013000", 84.0),
|
||||
("Statement period: 01.03.2026 - 31.03.2026", 108.0),
|
||||
("Opening balance: 1234.56 EUR", 132.0),
|
||||
("Closing balance: 1450.22 EUR", 156.0),
|
||||
("Statement date: 31.03.2026", 180.0),
|
||||
]
|
||||
lines: list[Line] = []
|
||||
for text, y_top in lines_meta:
|
||||
y_bot = y_top + 16.0
|
||||
lines.append(
|
||||
Line(
|
||||
text=text,
|
||||
bounding_box=[72.0, y_top, 500.0, y_top, 500.0, y_bot, 72.0, y_bot],
|
||||
)
|
||||
)
|
||||
return OCRResult(
|
||||
result=OCRDetails(
|
||||
text="\n".join(t for t, _ in lines_meta),
|
||||
pages=[
|
||||
Page(
|
||||
page_no=1,
|
||||
width=width,
|
||||
height=height,
|
||||
lines=lines,
|
||||
)
|
||||
],
|
||||
),
|
||||
meta_data={"engine": "fake"},
|
||||
)
|
||||
|
||||
|
||||
class _WrappedResponse(BaseModel):
|
||||
"""Mirrors the runtime ProvenanceWrappedResponse GenAIStep creates."""
|
||||
|
||||
result: BankStatementHeader
|
||||
segment_citations: list[SegmentCitation] = []
|
||||
|
||||
|
||||
def _canned_llm_output() -> _WrappedResponse:
|
||||
# After OCRStep injects <page> tag lines, the real OCR line at local
|
||||
# index 0 gets segment id p1_l0 (tag lines are skipped by
|
||||
# SegmentIndex.build). So:
|
||||
# p1_l0 -> "DKB"
|
||||
# p1_l1 -> "IBAN: DE89370400440532013000"
|
||||
# p1_l2 -> "Statement period: 01.03.2026 - 31.03.2026"
|
||||
# p1_l3 -> "Opening balance: 1234.56 EUR"
|
||||
# p1_l4 -> "Closing balance: 1450.22 EUR"
|
||||
# p1_l5 -> "Statement date: 31.03.2026"
|
||||
return _WrappedResponse(
|
||||
result=BankStatementHeader(
|
||||
bank_name=EXPECTED_BANK_NAME,
|
||||
account_iban=EXPECTED_IBAN,
|
||||
account_type="checking",
|
||||
currency=EXPECTED_CURRENCY,
|
||||
statement_date=EXPECTED_STATEMENT_DATE,
|
||||
statement_period_start=EXPECTED_PERIOD_START,
|
||||
statement_period_end=EXPECTED_PERIOD_END,
|
||||
opening_balance=EXPECTED_OPENING,
|
||||
closing_balance=EXPECTED_CLOSING,
|
||||
),
|
||||
segment_citations=[
|
||||
SegmentCitation(
|
||||
field_path="result.bank_name",
|
||||
value_segment_ids=["p1_l0"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.account_iban",
|
||||
value_segment_ids=["p1_l1"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.account_type",
|
||||
value_segment_ids=[],
|
||||
context_segment_ids=["p1_l0"],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.currency",
|
||||
value_segment_ids=["p1_l3", "p1_l4"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.statement_date",
|
||||
value_segment_ids=["p1_l5"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.statement_period_start",
|
||||
value_segment_ids=["p1_l2"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.statement_period_end",
|
||||
value_segment_ids=["p1_l2"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.opening_balance",
|
||||
value_segment_ids=["p1_l3"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
SegmentCitation(
|
||||
field_path="result.closing_balance",
|
||||
value_segment_ids=["p1_l4"],
|
||||
context_segment_ids=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_pipeline(fetch_config: Any = None) -> Pipeline:
|
||||
ocr_client = FakeOCRClient(canned=_canned_ocr_result())
|
||||
genai_client = FakeGenAIClient(
|
||||
parsed=_canned_llm_output(),
|
||||
usage=GenAIUsage(prompt_tokens=200, completion_tokens=400),
|
||||
model_name="fake-gpt",
|
||||
)
|
||||
setup = SetupStep(fetch_config=fetch_config) if fetch_config else SetupStep()
|
||||
return Pipeline(
|
||||
steps=[
|
||||
setup,
|
||||
OCRStep(ocr_client=ocr_client),
|
||||
GenAIStep(genai_client=genai_client),
|
||||
ReliabilityStep(),
|
||||
ResponseHandlerStep(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
@pytest.fixture
|
||||
def request_ix(self, tmp_path: Path) -> RequestIX:
|
||||
# Canonical single-file request pointing to the committed fixture
|
||||
# via file:// URL. Also includes a matching Paperless-style text
|
||||
# so text_agreement has real data to compare against.
|
||||
paperless_text = (
|
||||
"DKB statement. IBAN: DE89370400440532013000. Period 01.03.2026 - "
|
||||
"31.03.2026. Opening balance 1234.56 EUR. Closing balance 1450.22 EUR. "
|
||||
"Date 31.03.2026."
|
||||
)
|
||||
return RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="mammon-test",
|
||||
request_id="end-to-end-1",
|
||||
ix_id="abcd0123ef456789",
|
||||
context=Context(
|
||||
files=[FIXTURE_PDF.as_uri()],
|
||||
texts=[paperless_text],
|
||||
),
|
||||
options=Options(
|
||||
ocr=OCROptions(use_ocr=True),
|
||||
provenance=ProvenanceOptions(
|
||||
include_provenance=True, max_sources_per_field=5
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
async def test_ix_result_populated_from_fake_llm(self, request_ix: RequestIX) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
response = await pipeline.start(request_ix)
|
||||
assert response.error is None
|
||||
result = response.ix_result.result
|
||||
assert result["bank_name"] == EXPECTED_BANK_NAME
|
||||
assert result["account_iban"] == EXPECTED_IBAN
|
||||
assert result["currency"] == EXPECTED_CURRENCY
|
||||
# Pydantic v2 dumps Decimals as strings in mode="json".
|
||||
assert result["closing_balance"] == str(EXPECTED_CLOSING)
|
||||
|
||||
async def test_provenance_verified_for_closing_balance(
|
||||
self, request_ix: RequestIX
|
||||
) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
response = await pipeline.start(request_ix)
|
||||
assert response.provenance is not None
|
||||
fp = response.provenance.fields["result.closing_balance"]
|
||||
assert fp.provenance_verified is True
|
||||
|
||||
async def test_text_agreement_true_when_texts_match_value(
|
||||
self, request_ix: RequestIX
|
||||
) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
response = await pipeline.start(request_ix)
|
||||
assert response.provenance is not None
|
||||
fp = response.provenance.fields["result.closing_balance"]
|
||||
assert fp.text_agreement is True
|
||||
|
||||
async def test_timings_per_step(self, request_ix: RequestIX) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
response = await pipeline.start(request_ix)
|
||||
# Each of the five steps executed and recorded a timing.
|
||||
names = [t["step"] for t in response.metadata.timings]
|
||||
assert names == [
|
||||
"SetupStep",
|
||||
"OCRStep",
|
||||
"GenAIStep",
|
||||
"ReliabilityStep",
|
||||
"ResponseHandlerStep",
|
||||
]
|
||||
for entry in response.metadata.timings:
|
||||
assert isinstance(entry["elapsed_seconds"], float)
|
||||
|
||||
async def test_no_error_and_context_stripped(self, request_ix: RequestIX) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
response = await pipeline.start(request_ix)
|
||||
assert response.error is None
|
||||
dump = response.model_dump()
|
||||
assert "context" not in dump
|
||||
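The `_canned_ocr_result` docstring mentions that the raw page-point coordinates are later normalised to 0-1 values. A sketch of that normalisation for the 8-value polygon boxes used above, dividing x coordinates by the page width and y coordinates by the page height (helper name hypothetical, not the `ix` implementation):

```python
def normalize_bbox(bbox: list[float], width: float, height: float) -> list[float]:
    """Scale an [x0, y0, x1, y1, x2, y2, x3, y3] polygon into 0-1 space."""
    # Even indices are x values (divide by width), odd indices are y values.
    return [
        round(v / (width if i % 2 == 0 else height), 4)
        for i, v in enumerate(bbox)
    ]
```

With the 595x842 page above, the "DKB" line box `[72, 60, 500, 60, 500, 76, 72, 76]` lands comfortably inside the unit square, which is what makes the canned coordinates "plausible-but-not-exact".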
206 tests/unit/test_provenance_mapper.py Normal file
@@ -0,0 +1,206 @@
"""Tests for the provenance mapper (spec §9.4)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import (
|
||||
BoundingBox,
|
||||
Line,
|
||||
OCRDetails,
|
||||
OCRResult,
|
||||
Page,
|
||||
SegmentCitation,
|
||||
)
|
||||
from ix.provenance.mapper import (
|
||||
map_segment_refs_to_provenance,
|
||||
resolve_nested_path,
|
||||
)
|
||||
from ix.segmentation import PageMetadata, SegmentIndex
|
||||
|
||||
|
||||
def _make_index_with_lines(lines: list[tuple[str, int]]) -> SegmentIndex:
|
||||
"""Build a tiny index where each line has a known text + file_index.
|
||||
|
||||
Each entry is (text, file_index); all entries go on a single page.
|
||||
"""
|
||||
ocr_lines = [Line(text=t, bounding_box=[0, 0, 10, 0, 10, 5, 0, 5]) for t, _ in lines]
|
||||
page = Page(page_no=1, width=100.0, height=200.0, lines=ocr_lines)
|
||||
ocr = OCRResult(result=OCRDetails(pages=[page]))
|
||||
# file_index for the whole page — the test uses a single page.
|
||||
file_index = lines[0][1] if lines else 0
|
||||
return SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=[PageMetadata(file_index=file_index)],
|
||||
)
|
||||
|
||||
|
||||
class TestResolveNestedPath:
|
||||
def test_simple_path(self) -> None:
|
||||
assert resolve_nested_path({"result": {"a": "x"}}, "result.a") == "x"
|
||||
|
||||
def test_nested_path(self) -> None:
|
||||
data = {"result": {"header": {"bank": "UBS"}}}
|
||||
assert resolve_nested_path(data, "result.header.bank") == "UBS"
|
||||
|
||||
def test_missing_path_returns_none(self) -> None:
|
||||
assert resolve_nested_path({"result": {}}, "result.nope") is None
|
||||
|
||||
def test_array_bracket_notation_normalised(self) -> None:
|
||||
data = {"result": {"items": [{"name": "a"}, {"name": "b"}]}}
|
||||
assert resolve_nested_path(data, "result.items[0].name") == "a"
|
||||
assert resolve_nested_path(data, "result.items[1].name") == "b"
|
||||
|
||||
def test_array_dot_notation(self) -> None:
|
||||
data = {"result": {"items": [{"name": "a"}, {"name": "b"}]}}
|
||||
assert resolve_nested_path(data, "result.items.0.name") == "a"
|
||||
|
||||
|
||||
class TestMapper:
|
||||
def test_simple_single_field(self) -> None:
|
||||
idx = _make_index_with_lines([("UBS AG", 0), ("Header text", 0)])
|
||||
extraction = {"result": {"bank_name": "UBS AG"}}
|
||||
citations = [
|
||||
SegmentCitation(field_path="result.bank_name", value_segment_ids=["p1_l0"])
|
||||
]
|
||||
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result=extraction,
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
|
||||
fp = prov.fields["result.bank_name"]
|
||||
assert fp.field_name == "bank_name"
|
||||
assert fp.value == "UBS AG"
|
||||
assert len(fp.sources) == 1
|
||||
src = fp.sources[0]
|
||||
assert src.segment_id == "p1_l0"
|
||||
assert src.text_snippet == "UBS AG"
|
||||
assert src.page_number == 1
|
||||
assert src.file_index == 0
|
||||
assert isinstance(src.bounding_box, BoundingBox)
|
||||
# quality_metrics populated
|
||||
assert prov.quality_metrics["invalid_references"] == 0
|
||||
|
||||
def test_invalid_reference_counted(self) -> None:
|
||||
idx = _make_index_with_lines([("UBS AG", 0)])
|
||||
extraction = {"result": {"bank_name": "UBS AG"}}
|
||||
citations = [
|
||||
SegmentCitation(
|
||||
field_path="result.bank_name",
|
||||
value_segment_ids=["p1_l0", "p9_l9"], # p9_l9 doesn't exist
|
||||
)
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result=extraction,
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
assert prov.quality_metrics["invalid_references"] == 1
|
||||
# The one valid source still populated.
|
||||
assert len(prov.fields["result.bank_name"].sources) == 1
|
||||
|
||||
def test_max_sources_cap(self) -> None:
|
||||
# Five lines; ask for a cap of 2.
|
||||
idx = _make_index_with_lines([(f"line {i}", 0) for i in range(5)])
|
||||
citations = [
|
||||
SegmentCitation(
|
||||
field_path="result.notes",
|
||||
value_segment_ids=[f"p1_l{i}" for i in range(5)],
|
||||
)
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": {"notes": "noise"}},
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=2,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
assert len(prov.fields["result.notes"].sources) == 2
|
||||
|
||||
def test_source_type_value_only(self) -> None:
|
||||
idx = _make_index_with_lines([("label:", 0), ("UBS AG", 0)])
|
||||
citations = [
|
||||
SegmentCitation(
|
||||
field_path="result.bank_name",
|
||||
value_segment_ids=["p1_l1"],
|
||||
context_segment_ids=["p1_l0"],
|
||||
)
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": {"bank_name": "UBS AG"}},
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value",
|
||||
)
|
||||
sources = prov.fields["result.bank_name"].sources
|
||||
# Only value_segment_ids included.
|
||||
assert [s.segment_id for s in sources] == ["p1_l1"]
|
||||
|
||||
def test_source_type_value_and_context(self) -> None:
|
||||
idx = _make_index_with_lines([("label:", 0), ("UBS AG", 0)])
|
||||
citations = [
|
||||
SegmentCitation(
|
||||
field_path="result.bank_name",
|
||||
value_segment_ids=["p1_l1"],
|
||||
context_segment_ids=["p1_l0"],
|
||||
)
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": {"bank_name": "UBS AG"}},
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
sources = prov.fields["result.bank_name"].sources
|
||||
assert [s.segment_id for s in sources] == ["p1_l1", "p1_l0"]
|
||||
|
||||
def test_include_bounding_boxes_false(self) -> None:
|
||||
idx = _make_index_with_lines([("UBS AG", 0)])
|
||||
citations = [
|
||||
SegmentCitation(field_path="result.bank_name", value_segment_ids=["p1_l0"])
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": {"bank_name": "UBS AG"}},
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=False,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
assert prov.fields["result.bank_name"].sources[0].bounding_box is None
|
||||
|
||||
def test_field_with_no_valid_sources_skipped(self) -> None:
|
||||
idx = _make_index_with_lines([("UBS", 0)])
|
||||
citations = [
|
||||
SegmentCitation(field_path="result.ghost", value_segment_ids=["p9_l9"])
|
||||
]
|
||||
prov = map_segment_refs_to_provenance(
|
||||
extraction_result={"result": {"ghost": "x"}},
|
||||
segment_citations=citations,
|
||||
segment_index=idx,
|
||||
max_sources_per_field=10,
|
||||
min_confidence=0.0,
|
||||
include_bounding_boxes=True,
|
||||
source_type="value_and_context",
|
||||
)
|
||||
# Field not added when zero valid sources (spec §9.4 step).
|
||||
assert "result.ghost" not in prov.fields
|
||||
assert prov.quality_metrics["invalid_references"] == 1
|
||||
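The `TestResolveNestedPath` cases fully constrain the resolver's observable behaviour: dotted dict access, `items[0]`-style bracket indices normalised to dot segments, numeric dot segments for lists, and `None` for anything missing. One way to satisfy exactly those cases (an illustrative re-implementation, not the `ix.provenance.mapper` source):

```python
import re
from typing import Any


def resolve_nested_path(data: dict, path: str) -> Any:
    """Resolve 'result.items[0].name' or 'result.items.0.name'; None if absent."""
    # Normalise bracket indices to dot segments: items[0] -> items.0
    parts = re.sub(r"\[(\d+)\]", r".\1", path).split(".")
    current: Any = data
    for part in parts:
        if isinstance(current, dict):
            current = current.get(part)
        elif isinstance(current, list) and part.isdigit() and int(part) < len(current):
            current = current[int(part)]
        else:
            return None
        if current is None:
            return None
    return current
```

Normalising brackets first keeps the walk itself trivial: every segment is either a dict key or a list index.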
124 tests/unit/test_provenance_normalize.py Normal file
@@ -0,0 +1,124 @@
"""Tests for the provenance normalisers (spec §6 ReliabilityStep)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from ix.provenance.normalize import (
|
||||
normalize_date,
|
||||
normalize_iban,
|
||||
normalize_number,
|
||||
normalize_string,
|
||||
should_skip_text_agreement,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeString:
|
||||
def test_uppercase_casefolded_and_punctuation_stripped(self) -> None:
|
||||
assert normalize_string(" FOO bar!!! ") == "foo bar"
|
||||
|
||||
def test_nfkc_applied_for_fullwidth(self) -> None:
|
||||
# Fullwidth capital letters should NFKC-decompose to ASCII.
|
||||
fullwidth_ubs = "\uff35\uff22\uff33" # "UBS" in U+FF00 fullwidth block
|
||||
assert normalize_string(f"{fullwidth_ubs} AG") == "ubs ag"
|
||||
|
||||
def test_whitespace_collapse(self) -> None:
|
||||
assert normalize_string("UBS Switzerland\tAG") == "ubs switzerland ag"
|
||||
|
||||
def test_strips_common_punctuation(self) -> None:
|
||||
# Colons, commas, dots, semicolons, parens, slashes.
|
||||
assert normalize_string("Hello, World. (foo); bar: baz / qux") == (
|
||||
"hello world foo bar baz qux"
|
||||
)
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert normalize_string("") == ""
|
||||
|
||||
|
||||
class TestNormalizeNumber:
|
||||
def test_chf_swiss_apostrophe_thousands(self) -> None:
|
||||
assert normalize_number("CHF 1'234.56") == "1234.56"
|
||||
|
||||
def test_de_de_dot_thousands_and_comma_decimal(self) -> None:
|
||||
assert normalize_number("1.234,56 EUR") == "1234.56"
|
||||
|
||||
def test_negative_sign(self) -> None:
|
||||
assert normalize_number("-123.45") == "-123.45"
|
||||
assert normalize_number("CHF -1'234.56") == "-1234.56"
|
||||
|
||||
def test_int_input(self) -> None:
|
||||
assert normalize_number(42) == "42.00"
|
||||
|
||||
def test_float_input(self) -> None:
|
||||
assert normalize_number(1234.5) == "1234.50"
|
||||
|
||||
def test_decimal_input(self) -> None:
|
||||
assert normalize_number(Decimal("1234.56")) == "1234.56"
|
||||
|
||||
def test_trailing_zero_is_canonicalised(self) -> None:
|
||||
assert normalize_number("1234.5") == "1234.50"
|
||||
|
||||
def test_no_decimal_part(self) -> None:
|
||||
assert normalize_number("1234") == "1234.00"
|
||||
|
||||
|
||||
class TestNormalizeDate:
|
||||
def test_dayfirst_dotted(self) -> None:
|
||||
assert normalize_date("31.03.2026") == "2026-03-31"
|
||||
|
||||
def test_iso_date(self) -> None:
|
||||
assert normalize_date("2026-03-31") == "2026-03-31"
|
||||
|
||||
def test_date_object(self) -> None:
|
||||
assert normalize_date(date(2026, 3, 31)) == "2026-03-31"
|
||||
|
||||
def test_datetime_object(self) -> None:
|
||||
assert normalize_date(datetime(2026, 3, 31, 10, 30)) == "2026-03-31"
|
||||
|
||||
def test_slash_variant(self) -> None:
|
||||
assert normalize_date("31/03/2026") == "2026-03-31"
|
||||
|
||||
|
||||
class TestNormalizeIban:
|
||||
def test_uppercase_and_strip_whitespace(self) -> None:
|
||||
assert normalize_iban("de 89 3704 0044 0532 0130 00") == "DE89370400440532013000"
|
||||
|
||||
def test_already_normalised(self) -> None:
|
||||
assert normalize_iban("CH9300762011623852957") == "CH9300762011623852957"
|
||||
|
||||
def test_tabs_and_newlines(self) -> None:
|
||||
assert normalize_iban("ch 93\t0076\n2011623852957") == "CH9300762011623852957"
|
||||
|
||||
|
||||
class TestShouldSkipTextAgreement:
|
||||
def test_short_string_skipped(self) -> None:
|
||||
assert should_skip_text_agreement("AB", str) is True
|
||||
|
||||
def test_long_string_not_skipped(self) -> None:
|
||||
assert should_skip_text_agreement("ABC", str) is False
|
||||
|
||||
def test_number_abs_lt_10_skipped(self) -> None:
|
||||
assert should_skip_text_agreement(0, int) is True
|
||||
assert should_skip_text_agreement(9, int) is True
|
||||
assert should_skip_text_agreement(-9, int) is True
|
||||
assert should_skip_text_agreement(9.5, float) is True
|
||||
assert should_skip_text_agreement(Decimal("9.99"), Decimal) is True
|
||||
|
||||
def test_number_abs_ge_10_not_skipped(self) -> None:
|
||||
assert should_skip_text_agreement(10, int) is False
|
||||
assert should_skip_text_agreement(-10, int) is False
|
||||
assert should_skip_text_agreement(Decimal("1234.56"), Decimal) is False
|
||||
|
||||
def test_literal_type_skipped(self) -> None:
|
||||
lit = Literal["checking", "credit", "savings"]
|
||||
assert should_skip_text_agreement("checking", lit) is True
|
||||
|
||||
def test_none_value_skipped(self) -> None:
|
||||
assert should_skip_text_agreement(None, str) is True
|
||||
assert should_skip_text_agreement(None, None) is True
|
||||
|
||||
def test_numeric_string_treated_as_string(self) -> None:
|
||||
# Short stringified numeric values still trip the short-value rule.
|
||||
assert should_skip_text_agreement("9", str) is True
|
||||
220 tests/unit/test_provenance_verify.py Normal file
@@ -0,0 +1,220 @@
"""Tests for the reliability verifier (spec §6 ReliabilityStep)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ix.contracts import (
|
||||
ExtractionSource,
|
||||
FieldProvenance,
|
||||
ProvenanceData,
|
||||
)
|
||||
from ix.provenance.verify import apply_reliability_flags, verify_field
|
||||
|
||||
|
||||
def _make_fp(
|
||||
*,
|
||||
field_path: str,
|
||||
value: object,
|
||||
snippets: list[str],
|
||||
) -> FieldProvenance:
|
||||
return FieldProvenance(
|
||||
field_name=field_path.split(".")[-1],
|
||||
field_path=field_path,
|
||||
value=value,
|
||||
sources=[
|
||||
ExtractionSource(
|
||||
page_number=1,
|
||||
file_index=0,
|
||||
text_snippet=s,
|
||||
relevance_score=1.0,
|
||||
segment_id=f"p1_l{i}",
|
||||
)
|
||||
for i, s in enumerate(snippets)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestVerifyFieldByType:
|
||||
def test_string_substring_match(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["Account at UBS AG, Zurich"],
|
||||
)
|
||||
pv, ta = verify_field(fp, str, texts=[])
|
||||
assert pv is True
|
||||
assert ta is None
|
||||
|
||||
def test_string_mismatch_is_false(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["Credit Suisse"],
|
||||
)
|
||||
pv, _ = verify_field(fp, str, texts=[])
|
||||
assert pv is False
|
||||
|
||||
def test_number_decimal_match_ignores_currency(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.closing_balance",
|
||||
value=Decimal("1234.56"),
|
||||
snippets=["CHF 1'234.56"],
|
||||
)
|
||||
pv, _ = verify_field(fp, Decimal, texts=[])
|
||||
assert pv is True
|
||||
|
||||
def test_number_mismatch(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.closing_balance",
|
||||
value=Decimal("1234.56"),
|
||||
snippets=["CHF 9999.99"],
|
||||
)
|
||||
pv, _ = verify_field(fp, Decimal, texts=[])
|
||||
assert pv is False
|
||||
|
||||
def test_date_parse_both_sides(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.statement_date",
|
||||
value=date(2026, 3, 31),
|
||||
snippets=["Statement date: 31.03.2026"],
|
||||
)
|
||||
pv, _ = verify_field(fp, date, texts=[])
|
||||
assert pv is True
|
||||
|
||||
def test_iban_strip_and_case(self) -> None:
|
||||
# IBAN detection: field name contains "iban".
|
||||
fp = _make_fp(
|
||||
field_path="result.account_iban",
|
||||
value="CH9300762011623852957",
|
||||
snippets=["Account CH93 0076 2011 6238 5295 7"],
|
||||
)
|
||||
pv, _ = verify_field(fp, str, texts=[])
|
||||
assert pv is True
|
||||
|
||||
def test_literal_field_both_flags_none(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.account_type",
|
||||
value="checking",
|
||||
snippets=["the word checking is literally here"],
|
||||
)
|
||||
pv, ta = verify_field(fp, Literal["checking", "credit", "savings"], texts=["checking"])
|
||||
assert pv is None
|
||||
assert ta is None
|
||||
|
||||
def test_none_value_both_flags_none(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.account_iban",
|
||||
value=None,
|
||||
snippets=["whatever"],
|
||||
)
|
||||
pv, ta = verify_field(fp, str, texts=["whatever"])
|
||||
assert pv is None
|
||||
assert ta is None
|
||||
|
||||
|
||||
class TestTextAgreement:
|
||||
def test_text_agreement_with_texts_true(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["UBS AG"],
|
||||
)
|
||||
_, ta = verify_field(fp, str, texts=["Account at UBS AG"])
|
||||
assert ta is True
|
||||
|
||||
def test_text_agreement_with_texts_false(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["UBS AG"],
|
||||
)
|
||||
_, ta = verify_field(fp, str, texts=["Credit Suisse"])
|
||||
assert ta is False
|
||||
|
||||
def test_text_agreement_no_texts_is_none(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["UBS AG"],
|
||||
)
|
||||
_, ta = verify_field(fp, str, texts=[])
|
||||
assert ta is None
|
||||
|
||||
def test_short_value_skips_text_agreement(self) -> None:
|
||||
# 2-char string
|
||||
fp = _make_fp(
|
||||
field_path="result.code",
|
||||
value="XY",
|
||||
snippets=["code XY here"],
|
||||
)
|
||||
pv, ta = verify_field(fp, str, texts=["another XY reference"])
|
||||
# provenance_verified still runs; text_agreement is skipped.
|
||||
assert pv is True
|
||||
assert ta is None
|
||||
|
||||
def test_small_number_skips_text_agreement(self) -> None:
|
||||
fp = _make_fp(
|
||||
field_path="result.n",
|
||||
value=5,
|
||||
snippets=["value 5 here"],
|
||||
)
|
||||
pv, ta = verify_field(fp, int, texts=["the number 5"])
|
||||
assert pv is True
|
||||
assert ta is None
|
||||
|
||||
|
||||
class TestApplyReliabilityFlags:
|
||||
def test_writes_flags_and_counters(self) -> None:
|
||||
class BankHeader(BaseModel):
|
||||
bank_name: str
|
||||
account_iban: str | None = None
|
||||
closing_balance: Decimal | None = None
|
||||
account_type: Literal["checking", "credit", "savings"] | None = None
|
||||
|
||||
prov = ProvenanceData(
|
||||
fields={
|
||||
"result.bank_name": _make_fp(
|
||||
field_path="result.bank_name",
|
||||
value="UBS AG",
|
||||
snippets=["Account at UBS AG"],
|
||||
),
|
||||
"result.account_iban": _make_fp(
|
||||
field_path="result.account_iban",
|
||||
value="CH9300762011623852957",
|
||||
snippets=["IBAN CH93 0076 2011 6238 5295 7"],
|
||||
),
|
||||
"result.closing_balance": _make_fp(
|
||||
field_path="result.closing_balance",
|
||||
value=Decimal("1234.56"),
|
||||
snippets=["Closing balance CHF 1'234.56"],
|
||||
),
|
||||
"result.account_type": _make_fp(
|
||||
field_path="result.account_type",
|
||||
value="checking",
|
||||
snippets=["current account (checking)"],
|
||||
),
|
||||
},
|
||||
)
|
||||
apply_reliability_flags(prov, BankHeader, texts=["Account at UBS AG at CH9300762011623852957"])
|
||||
|
||||
fields = prov.fields
|
||||
assert fields["result.bank_name"].provenance_verified is True
|
||||
assert fields["result.bank_name"].text_agreement is True
|
||||
assert fields["result.account_iban"].provenance_verified is True
|
||||
assert fields["result.closing_balance"].provenance_verified is True
|
||||
# account_type is Literal → both flags None.
|
||||
assert fields["result.account_type"].provenance_verified is None
|
||||
assert fields["result.account_type"].text_agreement is None
|
||||
|
||||
# Counters record only True values.
|
||||
qm = prov.quality_metrics
|
||||
assert qm["verified_fields"] == 3 # all except Literal
|
||||
# text_agreement_fields counts only fields where the flag is True.
|
||||
# bank_name True; IBAN True (appears in texts after normalisation);
|
||||
# closing_balance -- '1234.56' doesn't appear in the text.
|
||||
assert qm["text_agreement_fields"] >= 1
|
||||
250 tests/unit/test_reliability_step.py Normal file
@@ -0,0 +1,250 @@
"""Tests for :class:`ix.pipeline.reliability_step.ReliabilityStep` (spec §6)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from ix.contracts import (
|
||||
BoundingBox,
|
||||
Context,
|
||||
ExtractionSource,
|
||||
FieldProvenance,
|
||||
OCROptions,
|
||||
Options,
|
||||
ProvenanceData,
|
||||
ProvenanceOptions,
|
||||
RequestIX,
|
||||
ResponseIX,
|
||||
)
|
||||
from ix.contracts.response import _InternalContext
|
||||
from ix.pipeline.reliability_step import ReliabilityStep
|
||||
from ix.use_cases.bank_statement_header import BankStatementHeader
|
||||
|
||||
|
||||
def _src(
|
||||
segment_id: str,
|
||||
text: str,
|
||||
page: int = 1,
|
||||
bbox: list[float] | None = None,
|
||||
) -> ExtractionSource:
|
||||
return ExtractionSource(
|
||||
page_number=page,
|
||||
file_index=0,
|
||||
bounding_box=BoundingBox(coordinates=bbox or [0, 0, 1, 0, 1, 1, 0, 1]),
|
||||
text_snippet=text,
|
||||
relevance_score=1.0,
|
||||
segment_id=segment_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_request(
|
||||
include_provenance: bool = True, texts: list[str] | None = None
|
||||
) -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="test",
|
||||
request_id="r-1",
|
||||
context=Context(files=[], texts=texts or []),
|
||||
options=Options(
|
||||
ocr=OCROptions(),
|
||||
provenance=ProvenanceOptions(include_provenance=include_provenance),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _response_with_provenance(
|
||||
fields: dict[str, FieldProvenance],
|
||||
texts: list[str] | None = None,
|
||||
) -> ResponseIX:
|
||||
resp = ResponseIX()
|
||||
resp.provenance = ProvenanceData(
|
||||
fields=fields,
|
||||
quality_metrics={},
|
||||
segment_count=10,
|
||||
granularity="line",
|
||||
)
|
||||
resp.context = _InternalContext(
|
||||
texts=texts or [],
|
||||
use_case_response=BankStatementHeader,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class TestValidate:
|
||||
async def test_skipped_when_provenance_off(self) -> None:
|
||||
step = ReliabilityStep()
|
||||
req = _make_request(include_provenance=False)
|
||||
resp = _response_with_provenance(fields={})
|
||||
assert await step.validate(req, resp) is False
|
||||
|
||||
async def test_skipped_when_no_provenance_data(self) -> None:
|
||||
step = ReliabilityStep()
|
||||
req = _make_request(include_provenance=True)
|
||||
resp = ResponseIX()
|
||||
assert await step.validate(req, resp) is False
|
||||
|
||||
async def test_runs_when_provenance_data_present(self) -> None:
|
||||
step = ReliabilityStep()
|
||||
req = _make_request(include_provenance=True)
|
||||
resp = _response_with_provenance(fields={})
|
||||
assert await step.validate(req, resp) is True
|
||||
|
||||
|
||||
class TestProcessFlags:
|
||||
async def test_string_field_verified_and_text_agreement(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="bank_name",
|
||||
field_path="result.bank_name",
|
||||
value="DKB",
|
||||
sources=[_src("p1_l0", "DKB")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.bank_name": fp},
|
||||
texts=["DKB statement content"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["DKB statement content"]), resp)
|
||||
out = resp.provenance.fields["result.bank_name"]
|
||||
assert out.provenance_verified is True
|
||||
assert out.text_agreement is True
|
||||
|
||||
async def test_literal_field_flags_none(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="account_type",
|
||||
field_path="result.account_type",
|
||||
value="checking",
|
||||
sources=[_src("p1_l0", "anything")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.account_type": fp},
|
||||
texts=["some text"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["some text"]), resp)
|
||||
out = resp.provenance.fields["result.account_type"]
|
||||
assert out.provenance_verified is None
|
||||
assert out.text_agreement is None
|
||||
|
||||
async def test_none_value_flags_none(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="account_iban",
|
||||
field_path="result.account_iban",
|
||||
value=None,
|
||||
sources=[_src("p1_l0", "IBAN blah")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.account_iban": fp},
|
||||
texts=["text"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["text"]), resp)
|
||||
out = resp.provenance.fields["result.account_iban"]
|
||||
assert out.provenance_verified is None
|
||||
assert out.text_agreement is None
|
||||
|
||||
async def test_short_value_text_agreement_skipped(self) -> None:
|
||||
# Closing balance value < 10 → short numeric skip rule.
|
||||
fp = FieldProvenance(
|
||||
field_name="opening_balance",
|
||||
field_path="result.opening_balance",
|
||||
value=Decimal("5.00"),
|
||||
sources=[_src("p1_l0", "balance 5.00")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.opening_balance": fp},
|
||||
texts=["balance 5.00"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["balance 5.00"]), resp)
|
||||
out = resp.provenance.fields["result.opening_balance"]
|
||||
assert out.provenance_verified is True # bbox cite still runs
|
||||
assert out.text_agreement is None # short-value skip
|
||||
|
||||
async def test_date_field_parses_both_sides(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="statement_date",
|
||||
field_path="result.statement_date",
|
||||
value=date(2026, 3, 31),
|
||||
sources=[_src("p1_l0", "Statement date 31.03.2026")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.statement_date": fp},
|
||||
texts=["Statement date 2026-03-31"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["Statement date 2026-03-31"]), resp)
|
||||
out = resp.provenance.fields["result.statement_date"]
|
||||
assert out.provenance_verified is True
|
||||
assert out.text_agreement is True
|
||||
|
||||
async def test_iban_field_whitespace_ignored(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="account_iban",
|
||||
field_path="result.account_iban",
|
||||
value="DE89370400440532013000",
|
||||
sources=[_src("p1_l0", "IBAN DE89 3704 0044 0532 0130 00")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.account_iban": fp},
|
||||
texts=["IBAN DE89 3704 0044 0532 0130 00"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["IBAN DE89 3704 0044 0532 0130 00"]), resp)
|
||||
out = resp.provenance.fields["result.account_iban"]
|
||||
assert out.provenance_verified is True
|
||||
assert out.text_agreement is True
|
||||
|
||||
async def test_disagreeing_snippet_sets_false(self) -> None:
|
||||
fp = FieldProvenance(
|
||||
field_name="bank_name",
|
||||
field_path="result.bank_name",
|
||||
value="DKB",
|
||||
sources=[_src("p1_l0", "Commerzbank")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={"result.bank_name": fp},
|
||||
texts=["Commerzbank header"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["Commerzbank header"]), resp)
|
||||
out = resp.provenance.fields["result.bank_name"]
|
||||
assert out.provenance_verified is False
|
||||
assert out.text_agreement is False
|
||||
|
||||
|
||||
class TestCounters:
|
||||
async def test_quality_metrics_counters_written(self) -> None:
|
||||
fp_ok = FieldProvenance(
|
||||
field_name="bank_name",
|
||||
field_path="result.bank_name",
|
||||
value="DKB",
|
||||
sources=[_src("p1_l0", "DKB")],
|
||||
)
|
||||
fp_bad = FieldProvenance(
|
||||
field_name="currency",
|
||||
field_path="result.currency",
|
||||
value="EUR",
|
||||
sources=[_src("p1_l1", "nothing to see")],
|
||||
)
|
||||
fp_literal = FieldProvenance(
|
||||
field_name="account_type",
|
||||
field_path="result.account_type",
|
||||
value="checking",
|
||||
sources=[_src("p1_l2", "anything")],
|
||||
)
|
||||
resp = _response_with_provenance(
|
||||
fields={
|
||||
"result.bank_name": fp_ok,
|
||||
"result.currency": fp_bad,
|
||||
"result.account_type": fp_literal,
|
||||
},
|
||||
texts=["DKB statement"],
|
||||
)
|
||||
step = ReliabilityStep()
|
||||
resp = await step.process(_make_request(texts=["DKB statement"]), resp)
|
||||
|
||||
qm = resp.provenance.quality_metrics
|
||||
# bank_name verified+agree (2 flags), others not.
|
||||
assert qm["verified_fields"] == 1
|
||||
assert qm["text_agreement_fields"] == 1
|
||||
136 tests/unit/test_response_handler_step.py Normal file
@@ -0,0 +1,136 @@
"""Tests for :class:`ix.pipeline.response_handler_step.ResponseHandlerStep` (spec §8)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import (
|
||||
Context,
|
||||
Line,
|
||||
OCRDetails,
|
||||
OCROptions,
|
||||
OCRResult,
|
||||
Options,
|
||||
Page,
|
||||
RequestIX,
|
||||
ResponseIX,
|
||||
)
|
||||
from ix.contracts.response import _InternalContext
|
||||
from ix.pipeline.response_handler_step import ResponseHandlerStep
|
||||
|
||||
|
||||
def _make_request(
|
||||
*,
|
||||
include_geometries: bool = False,
|
||||
include_ocr_text: bool = False,
|
||||
) -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case="bank_statement_header",
|
||||
ix_client_id="test",
|
||||
request_id="r-1",
|
||||
context=Context(files=[], texts=[]),
|
||||
options=Options(
|
||||
ocr=OCROptions(
|
||||
include_geometries=include_geometries,
|
||||
include_ocr_text=include_ocr_text,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _populated_response() -> ResponseIX:
|
||||
resp = ResponseIX(
|
||||
ocr_result=OCRResult(
|
||||
result=OCRDetails(
|
||||
text=None,
|
||||
pages=[
|
||||
Page(
|
||||
page_no=1,
|
||||
width=100.0,
|
||||
height=200.0,
|
||||
lines=[
|
||||
Line(text='<page file="0" number="1">', bounding_box=[]),
|
||||
Line(text="hello", bounding_box=[0, 0, 1, 0, 1, 1, 0, 1]),
|
||||
Line(text="world", bounding_box=[0, 2, 1, 2, 1, 3, 0, 3]),
|
||||
Line(text="</page>", bounding_box=[]),
|
||||
],
|
||||
),
|
||||
Page(
|
||||
page_no=2,
|
||||
width=100.0,
|
||||
height=200.0,
|
||||
lines=[
|
||||
Line(text="p2 line", bounding_box=[0, 0, 1, 0, 1, 1, 0, 1]),
|
||||
],
|
||||
),
|
||||
],
|
||||
),
|
||||
meta_data={"adapter": "fake"},
|
||||
)
|
||||
)
|
||||
resp.context = _InternalContext()
|
||||
return resp
|
||||
|
||||
|
||||
class TestValidateAlwaysTrue:
|
||||
async def test_validate_always_true(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request()
|
||||
assert await step.validate(req, _populated_response()) is True
|
||||
|
||||
|
||||
class TestAttachOcrText:
|
||||
async def test_include_ocr_text_concatenates_lines(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request(include_ocr_text=True, include_geometries=True)
|
||||
resp = _populated_response()
|
||||
resp = await step.process(req, resp)
|
||||
# Page tag lines excluded; real lines joined within page with \n,
|
||||
# pages with \n\n.
|
||||
text = resp.ocr_result.result.text
|
||||
assert text is not None
|
||||
assert "hello\nworld" in text
|
||||
assert "p2 line" in text
|
||||
assert "<page" not in text
|
||||
|
||||
async def test_include_ocr_text_false_leaves_text_alone(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request(include_ocr_text=False, include_geometries=True)
|
||||
resp = _populated_response()
|
||||
resp.ocr_result.result.text = None
|
||||
resp = await step.process(req, resp)
|
||||
assert resp.ocr_result.result.text is None
|
||||
|
||||
|
||||
class TestStripGeometries:
|
||||
async def test_strips_pages_and_meta_when_off(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request(include_geometries=False)
|
||||
resp = _populated_response()
|
||||
resp = await step.process(req, resp)
|
||||
assert resp.ocr_result.result.pages == []
|
||||
assert resp.ocr_result.meta_data == {}
|
||||
|
||||
async def test_keeps_pages_when_on(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request(include_geometries=True)
|
||||
resp = _populated_response()
|
||||
pages_before = [p.page_no for p in resp.ocr_result.result.pages]
|
||||
resp = await step.process(req, resp)
|
||||
assert [p.page_no for p in resp.ocr_result.result.pages] == pages_before
|
||||
assert resp.ocr_result.meta_data == {"adapter": "fake"}
|
||||
|
||||
|
||||
class TestContextDeletion:
|
||||
async def test_context_removed(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request()
|
||||
resp = _populated_response()
|
||||
resp = await step.process(req, resp)
|
||||
assert resp.context is None
|
||||
|
||||
async def test_context_not_in_model_dump(self) -> None:
|
||||
step = ResponseHandlerStep()
|
||||
req = _make_request()
|
||||
resp = _populated_response()
|
||||
resp = await step.process(req, resp)
|
||||
dump = resp.model_dump()
|
||||
assert "context" not in dump
|
||||
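The `TestAttachOcrText` assertions above describe the concatenation rule: synthetic `<page ...>` marker lines are dropped, real lines are joined with `\n` inside a page, and pages are joined with `\n\n`. A minimal sketch of that rule over plain line lists — not the step's actual implementation — might look like:

```python
def build_ocr_text(pages: list[list[str]]) -> str:
    """Join real lines with \\n inside a page and \\n\\n between pages,
    dropping the synthetic <page ...> / </page> marker lines."""
    page_texts: list[str] = []
    for lines in pages:
        real = [ln for ln in lines if not (ln.startswith("<page") or ln == "</page>")]
        if real:  # pages with only marker lines contribute nothing
            page_texts.append("\n".join(real))
    return "\n\n".join(page_texts)
```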
190 tests/unit/test_segment_index.py Normal file
@@ -0,0 +1,190 @@
"""Tests for the SegmentIndex — spec §9.1."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import Line, OCRDetails, OCRResult, Page
|
||||
from ix.segmentation import PageMetadata, SegmentIndex
|
||||
|
||||
|
||||
def _make_pages_metadata(n: int, file_index: int = 0) -> list[PageMetadata]:
|
||||
"""Build ``n`` flat-list page entries carrying only file_index."""
|
||||
return [PageMetadata(file_index=file_index) for _ in range(n)]
|
||||
|
||||
|
||||
def _line(text: str, bbox: list[float]) -> Line:
|
||||
return Line(text=text, bounding_box=bbox)
|
||||
|
||||
|
||||
def _page(page_no: int, width: float, height: float, lines: list[Line]) -> Page:
|
||||
return Page(page_no=page_no, width=width, height=height, lines=lines)
|
||||
|
||||
|
||||
class TestBuild:
|
||||
def test_ids_per_page(self) -> None:
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(1, 100.0, 200.0, [_line("hello", [0, 0, 10, 0, 10, 20, 0, 20])]),
|
||||
_page(
|
||||
2,
|
||||
100.0,
|
||||
200.0,
|
||||
[
|
||||
_line("foo", [0, 0, 10, 0, 10, 20, 0, 20]),
|
||||
_line("bar", [0, 30, 10, 30, 10, 50, 0, 50]),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=_make_pages_metadata(2),
|
||||
)
|
||||
|
||||
assert idx._ordered_ids == ["p1_l0", "p2_l0", "p2_l1"]
|
||||
|
||||
pos = idx.lookup_segment("p1_l0")
|
||||
assert pos is not None
|
||||
assert pos["page"] == 1
|
||||
assert pos["text"] == "hello"
|
||||
assert pos["file_index"] == 0
|
||||
|
||||
def test_page_tag_lines_excluded(self) -> None:
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(
|
||||
1,
|
||||
100.0,
|
||||
200.0,
|
||||
[
|
||||
_line('<page file="0" number="1">', [0, 0, 10, 0, 10, 5, 0, 5]),
|
||||
_line("first real line", [0, 10, 10, 10, 10, 20, 0, 20]),
|
||||
_line("</page>", [0, 25, 10, 25, 10, 30, 0, 30]),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=_make_pages_metadata(1),
|
||||
)
|
||||
assert idx._ordered_ids == ["p1_l0"]
|
||||
assert idx.lookup_segment("p1_l0")["text"] == "first real line" # type: ignore[index]
|
||||
|
||||
def test_lookup_unknown_returns_none(self) -> None:
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=OCRResult(result=OCRDetails(pages=[])),
|
||||
granularity="line",
|
||||
pages_metadata=[],
|
||||
)
|
||||
assert idx.lookup_segment("pX_l99") is None
|
||||
|
||||
|
||||
class TestBboxNormalization:
|
||||
def test_divides_by_page_width_and_height(self) -> None:
|
||||
# x-coords get /width, y-coords get /height.
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(
|
||||
1,
|
||||
width=200.0,
|
||||
height=400.0,
|
||||
lines=[_line("x", [50, 100, 150, 100, 150, 300, 50, 300])],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=_make_pages_metadata(1),
|
||||
)
|
||||
pos = idx.lookup_segment("p1_l0")
|
||||
assert pos is not None
|
||||
bbox = pos["bbox"]
|
||||
# Compare with a bit of float slack.
|
||||
assert bbox.coordinates == [0.25, 0.25, 0.75, 0.25, 0.75, 0.75, 0.25, 0.75]
|
||||
|
||||
|
||||
class TestPromptFormat:
|
||||
def test_tagged_lines_and_untagged_texts_appended(self) -> None:
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(
|
||||
1,
|
||||
100.0,
|
||||
200.0,
|
||||
[
|
||||
_line("line one", [0, 0, 10, 0, 10, 5, 0, 5]),
|
||||
_line("line two", [0, 10, 10, 10, 10, 15, 0, 15]),
|
||||
],
|
||||
),
|
||||
_page(2, 100.0, 200.0, [_line("line A", [0, 0, 10, 0, 10, 5, 0, 5])]),
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=_make_pages_metadata(2),
|
||||
)
|
||||
text = idx.to_prompt_text(context_texts=["extra paperless text", "another"])
|
||||
|
||||
lines = text.split("\n")
|
||||
# Tagged OCR lines first, in insertion order.
|
||||
assert lines[0] == "[p1_l0] line one"
|
||||
assert lines[1] == "[p1_l1] line two"
|
||||
assert lines[2] == "[p2_l0] line A"
|
||||
# The extra texts are appended untagged.
|
||||
assert "extra paperless text" in text
|
||||
assert "another" in text
|
||||
# Sanity: the pageless texts should appear AFTER the last tagged line.
|
||||
p2_idx = text.index("[p2_l0]")
|
||||
extra_idx = text.index("extra paperless text")
|
||||
assert extra_idx > p2_idx
|
||||
|
||||
def test_prompt_text_without_extra_texts(self) -> None:
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(1, 100.0, 200.0, [_line("only", [0, 0, 10, 0, 10, 5, 0, 5])]),
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=_make_pages_metadata(1),
|
||||
)
|
||||
text = idx.to_prompt_text(context_texts=[])
|
||||
assert text.strip() == "[p1_l0] only"
|
||||
|
||||
|
||||
class TestFileIndexPassthrough:
|
||||
def test_file_index_from_metadata(self) -> None:
|
||||
pages_meta = [
|
||||
PageMetadata(file_index=0),
|
||||
PageMetadata(file_index=1),
|
||||
]
|
||||
ocr = OCRResult(
|
||||
result=OCRDetails(
|
||||
pages=[
|
||||
_page(1, 100.0, 200.0, [_line("a", [0, 0, 10, 0, 10, 5, 0, 5])]),
|
||||
_page(1, 100.0, 200.0, [_line("b", [0, 0, 10, 0, 10, 5, 0, 5])]),
|
||||
]
|
||||
)
|
||||
)
|
||||
idx = SegmentIndex.build(
|
||||
ocr_result=ocr,
|
||||
granularity="line",
|
||||
pages_metadata=pages_meta,
|
||||
)
|
||||
assert idx.lookup_segment("p1_l0")["file_index"] == 0 # type: ignore[index]
|
||||
assert idx.lookup_segment("p2_l0")["file_index"] == 1 # type: ignore[index]
|
||||
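The `TestPromptFormat` assertions fix the prompt layout: each OCR line is prefixed with its `[p{page}_l{line}]` segment id in insertion order, and any pageless context texts are appended untagged at the end. A minimal sketch of that tagging scheme over plain line lists — an illustration, not the `SegmentIndex` implementation — could be:

```python
def to_prompt_text(pages: list[list[str]], context_texts: list[str]) -> str:
    """Tag each OCR line with its segment id, then append untagged texts."""
    out: list[str] = []
    for page_no, lines in enumerate(pages, start=1):
        for line_no, line in enumerate(lines):
            out.append(f"[p{page_no}_l{line_no}] {line}")
    out.extend(context_texts)  # pageless texts go last, untagged
    return "\n".join(out)
```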
357 tests/unit/test_setup_step.py Normal file
@@ -0,0 +1,357 @@
"""Tests for :class:`ix.pipeline.setup_step.SetupStep` (spec §6.1)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.contracts import (
|
||||
Context,
|
||||
FileRef,
|
||||
OCROptions,
|
||||
Options,
|
||||
ProvenanceOptions,
|
||||
RequestIX,
|
||||
ResponseIX,
|
||||
)
|
||||
from ix.contracts.request import InlineUseCase, UseCaseFieldDef
|
||||
from ix.contracts.response import _InternalContext
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.ingestion import FetchConfig
|
||||
from ix.pipeline.setup_step import SetupStep
|
||||
from ix.segmentation import PageMetadata
|
||||
|
||||
|
||||
class FakeFetcher:
|
||||
"""Captures FileRef + tmp_dir + cfg; returns a pre-set path per URL."""
|
||||
|
||||
def __init__(self, routes: dict[str, Path]) -> None:
|
||||
self.routes = routes
|
||||
self.calls: list[tuple[FileRef, Path, FetchConfig]] = []
|
||||
|
||||
async def __call__(
|
||||
self, file_ref: FileRef, tmp_dir: Path, cfg: FetchConfig
|
||||
) -> Path:
|
||||
self.calls.append((file_ref, tmp_dir, cfg))
|
||||
if file_ref.url not in self.routes:
|
||||
raise IXException(IXErrorCode.IX_000_007, detail=file_ref.url)
|
||||
return self.routes[file_ref.url]
|
||||
|
||||
|
||||
class FakeIngestor:
|
||||
"""Returns canned pages + metas; records build_pages arguments."""
|
||||
|
||||
def __init__(self, pages_by_file: list[list]) -> None:
|
||||
# Each entry corresponds to one file in the input.
|
||||
self.pages_by_file = pages_by_file
|
||||
self.build_calls: list[tuple[list, list[str]]] = []
|
||||
|
||||
def build_pages(
|
||||
self,
|
||||
files: list[tuple[Path, str]],
|
||||
texts: list[str],
|
||||
) -> tuple[list, list[PageMetadata]]:
|
||||
self.build_calls.append((files, texts))
|
||||
|
||||
# Flat out pages keyed by file_index.
|
||||
from ix.contracts import Page
|
||||
|
||||
pages: list = []
|
||||
metas: list[PageMetadata] = []
|
||||
for file_index, _ in enumerate(files):
|
||||
canned = self.pages_by_file[file_index]
|
||||
for w, h in canned:
|
||||
pages.append(
|
||||
Page(page_no=len(pages) + 1, width=w, height=h, lines=[])
|
||||
)
|
||||
metas.append(PageMetadata(file_index=file_index))
|
||||
for _ in texts:
|
||||
pages.append(Page(page_no=len(pages) + 1, width=0.0, height=0.0, lines=[]))
|
||||
metas.append(PageMetadata(file_index=None))
|
||||
return pages, metas
|
||||
|
||||
|
||||
class _AlwaysMimePdf:
|
||||
"""detect_mime replacement that always returns application/pdf."""
|
||||
|
||||
def __call__(self, path: Path) -> str:
|
||||
return "application/pdf"
|
||||
|
||||
|
||||
def _make_response() -> ResponseIX:
|
||||
return ResponseIX()
|
||||
|
||||
|
||||
def _make_cfg() -> FetchConfig:
|
||||
return FetchConfig(connect_timeout_s=1.0, read_timeout_s=2.0, max_bytes=10_000)
|
||||
|
||||
|
||||
def _make_request(
|
||||
files: list[str | FileRef] | None = None,
|
||||
texts: list[str] | None = None,
|
||||
use_case: str = "bank_statement_header",
|
||||
) -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case=use_case,
|
||||
ix_client_id="test",
|
||||
request_id="r-1",
|
||||
context=Context(files=files or [], texts=texts or []),
|
||||
options=Options(
|
||||
ocr=OCROptions(use_ocr=True),
|
||||
provenance=ProvenanceOptions(include_provenance=True),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestValidate:
|
||||
async def test_empty_context_raises_IX_000_002(self, tmp_path: Path) -> None:
|
||||
step = SetupStep(
|
||||
fetcher=FakeFetcher({}),
|
||||
ingestor=FakeIngestor([]),
|
||||
tmp_dir=tmp_path,
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(files=[], texts=[])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.validate(req, _make_response())
|
||||
assert ei.value.code is IXErrorCode.IX_000_002
|
||||
|
||||
|
||||
class TestProcessHappyPath:
|
||||
async def test_files_downloaded_mime_checked_use_case_loaded(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
routes = {"http://host/a.pdf": tmp_path / "a.pdf"}
|
||||
for p in routes.values():
|
||||
p.write_bytes(b"%PDF-1.4")
|
||||
fetcher = FakeFetcher(routes)
|
||||
ingestor = FakeIngestor([[(200.0, 300.0), (200.0, 300.0)]])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(files=["http://host/a.pdf"])
|
||||
resp = _make_response()
|
||||
assert await step.validate(req, resp) is True
|
||||
resp = await step.process(req, resp)
|
||||
|
||||
# Fetcher invoked once with the URL wrapped in a FileRef.
|
||||
assert len(fetcher.calls) == 1
|
||||
assert fetcher.calls[0][0].url == "http://host/a.pdf"
|
||||
|
||||
# Ingestor received [(local_path, mime)] + empty texts.
|
||||
assert len(ingestor.build_calls) == 1
|
||||
files, texts = ingestor.build_calls[0]
|
||||
assert files == [(routes["http://host/a.pdf"], "application/pdf")]
|
||||
assert texts == []
|
||||
|
||||
# Context populated.
|
||||
ctx = resp.context
|
||||
assert ctx is not None
|
||||
assert len(getattr(ctx, "pages", [])) == 2
|
||||
assert len(getattr(ctx, "page_metadata", [])) == 2
|
||||
assert getattr(ctx, "texts", None) == []
|
||||
assert getattr(ctx, "files", None) is not None
|
||||
|
||||
# Use case echoed.
|
||||
assert resp.use_case_name == "Bank Statement Header"
|
||||
|
||||
async def test_fileref_headers_pass_through(self, tmp_path: Path) -> None:
|
||||
routes = {"http://host/with-auth.pdf": tmp_path / "f.pdf"}
|
||||
for p in routes.values():
|
||||
p.write_bytes(b"%PDF-1.4")
|
||||
fetcher = FakeFetcher(routes)
|
||||
ingestor = FakeIngestor([[(10.0, 10.0)]])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(
|
||||
files=[FileRef(url="http://host/with-auth.pdf", headers={"Authorization": "Token z"})],
|
||||
)
|
||||
await step.process(req, _make_response())
|
||||
fr = fetcher.calls[0][0]
|
||||
assert fr.headers == {"Authorization": "Token z"}
|
||||
|
||||
|
||||
class TestProcessErrors:
|
||||
async def test_unsupported_mime_raises_IX_000_005(self, tmp_path: Path) -> None:
|
||||
routes = {"http://host/a.txt": tmp_path / "a.txt"}
|
||||
routes["http://host/a.txt"].write_bytes(b"hello")
|
||||
fetcher = FakeFetcher(routes)
|
||||
ingestor = FakeIngestor([[(10.0, 10.0)]])
|
||||
|
||||
class _TextMime:
|
||||
def __call__(self, path: Path) -> str:
|
||||
return "text/plain"
|
||||
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_TextMime(),
|
||||
)
|
||||
req = _make_request(files=["http://host/a.txt"])
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, _make_response())
|
||||
assert ei.value.code is IXErrorCode.IX_000_005
|
||||
|
||||
async def test_unknown_use_case_raises_IX_001_001(self, tmp_path: Path) -> None:
|
||||
step = SetupStep(
|
||||
fetcher=FakeFetcher({}),
|
||||
ingestor=FakeIngestor([]),
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(files=[], texts=["hello"], use_case="nope")
|
||||
# Validate passes (we have context). Process should raise IX_001_001.
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, _make_response())
|
||||
assert ei.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
|
||||
class TestTextOnly:
|
||||
async def test_texts_only_loads_use_case_and_builds_text_pages(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
fetcher = FakeFetcher({})
|
||||
ingestor = FakeIngestor([])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(files=[], texts=["hello", "there"])
|
||||
resp = _make_response()
|
||||
assert await step.validate(req, resp) is True
|
||||
resp = await step.process(req, resp)
|
||||
|
||||
assert fetcher.calls == []
|
||||
assert ingestor.build_calls[0][1] == ["hello", "there"]
|
||||
ctx = resp.context
|
||||
assert ctx is not None
|
||||
assert ctx.texts == ["hello", "there"]
|
||||
|
||||
|
||||
class TestInlineUseCase:
|
||||
def _make_inline_request(
|
||||
self,
|
||||
inline: InlineUseCase,
|
||||
use_case: str = "adhoc-label",
|
||||
texts: list[str] | None = None,
|
||||
) -> RequestIX:
|
||||
return RequestIX(
|
||||
use_case=use_case,
|
||||
use_case_inline=inline,
|
||||
ix_client_id="test",
|
||||
request_id="r-inline",
|
||||
context=Context(files=[], texts=texts or ["hello"]),
|
||||
options=Options(
|
||||
ocr=OCROptions(use_ocr=True),
|
||||
provenance=ProvenanceOptions(include_provenance=True),
|
||||
),
|
||||
)
|
||||
|
||||
async def test_inline_use_case_overrides_registry(self, tmp_path: Path) -> None:
|
||||
fetcher = FakeFetcher({})
|
||||
ingestor = FakeIngestor([])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
inline = InlineUseCase(
|
||||
use_case_name="adhoc",
|
||||
system_prompt="Extract things.",
|
||||
fields=[
|
||||
UseCaseFieldDef(name="vendor", type="str", required=True),
|
||||
UseCaseFieldDef(name="amount", type="decimal"),
|
||||
],
|
||||
)
|
||||
req = self._make_inline_request(inline)
|
||||
resp = _make_response()
|
||||
resp = await step.process(req, resp)
|
||||
|
||||
ctx = resp.context
|
||||
assert ctx is not None
|
||||
# The response class must have been built from our field list.
|
||||
resp_cls = ctx.use_case_response # type: ignore[union-attr]
|
||||
assert set(resp_cls.model_fields.keys()) == {"vendor", "amount"}
|
||||
# Public display name reflects the inline label.
|
||||
assert resp.use_case_name == "adhoc"
|
||||
|
||||
async def test_inline_precedence_when_both_set(self, tmp_path: Path) -> None:
|
||||
# ``use_case`` is a valid registered name; ``use_case_inline`` is also
|
||||
# present. Inline MUST win (documented precedence).
|
||||
fetcher = FakeFetcher({})
|
||||
ingestor = FakeIngestor([])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
inline = InlineUseCase(
|
||||
use_case_name="override",
|
||||
system_prompt="override prompt",
|
||||
fields=[UseCaseFieldDef(name="just_me", type="str", required=True)],
|
||||
)
|
||||
req = self._make_inline_request(
|
||||
inline, use_case="bank_statement_header"
|
||||
)
|
||||
resp = await step.process(req, _make_response())
|
||||
resp_cls = resp.context.use_case_response # type: ignore[union-attr]
|
||||
assert set(resp_cls.model_fields.keys()) == {"just_me"}
|
||||
|
||||
async def test_inline_with_bad_field_raises_ix_001_001(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
fetcher = FakeFetcher({})
|
||||
ingestor = FakeIngestor([])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
inline = InlineUseCase(
|
||||
use_case_name="bad",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="123bad", type="str")],
|
||||
)
|
||||
req = self._make_inline_request(inline)
|
||||
with pytest.raises(IXException) as ei:
|
||||
await step.process(req, _make_response())
|
||||
assert ei.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
|
||||
class TestInternalContextShape:
|
||||
async def test_context_is_internal_context_instance(self, tmp_path: Path) -> None:
|
||||
fetcher = FakeFetcher({})
|
||||
ingestor = FakeIngestor([])
|
||||
step = SetupStep(
|
||||
fetcher=fetcher,
|
||||
ingestor=ingestor,
|
||||
tmp_dir=tmp_path / "work",
|
||||
fetch_config=_make_cfg(),
|
||||
mime_detector=_AlwaysMimePdf(),
|
||||
)
|
||||
req = _make_request(files=[], texts=["hello"])
|
||||
resp = await step.process(req, _make_response())
|
||||
assert isinstance(resp.context, _InternalContext)
|
||||
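A minimal sketch of the MIME gate the error tests above exercise: a detected type outside an allow-list is rejected before ingestion. The allow-list contents and names below are assumptions for illustration; the real set and the IX_000_005 mapping live in the ix package.

```python
# Placeholder allow-list; the real SetupStep's supported MIME set is assumed.
ALLOWED_MIMES = {"application/pdf", "image/png", "image/jpeg", "image/tiff"}


class UnsupportedMime(Exception):
    """Stand-in for IXException(IXErrorCode.IX_000_005)."""


def gate_mime(detected: str) -> str:
    """Return the detected MIME type, or raise if it is not ingestible."""
    if detected not in ALLOWED_MIMES:
        raise UnsupportedMime(detected)
    return detected


gate_mime("application/pdf")  # passes through unchanged
```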
tests/unit/test_surya_client.py (new file, 238 lines added)
@@ -0,0 +1,238 @@
"""Tests for :class:`SuryaOCRClient` — hermetic, no model download.
|
||||
|
||||
The real Surya predictors are patched out with :class:`unittest.mock.MagicMock`
|
||||
that return trivially-shaped line objects. The tests assert the client's
|
||||
translation layer — flattening polygons, mapping text_lines → ``Line``,
|
||||
preserving ``page_no``/``width``/``height`` per input page.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.contracts import Page
|
||||
from ix.ocr.surya_client import SuryaOCRClient
|
||||
from ix.segmentation import PageMetadata
|
||||
|
||||
|
||||
def _make_surya_line(text: str, polygon: list[list[float]]) -> SimpleNamespace:
|
||||
"""Mimic ``surya.recognition.schema.TextLine`` duck-typing-style."""
|
||||
return SimpleNamespace(text=text, polygon=polygon, confidence=0.95)
|
||||
|
||||
|
||||
def _make_surya_ocr_result(lines: list[SimpleNamespace]) -> SimpleNamespace:
|
||||
"""Mimic ``surya.recognition.schema.OCRResult``."""
|
||||
return SimpleNamespace(text_lines=lines, image_bbox=[0, 0, 100, 100])
|
||||
|
||||
|
||||
class TestOCRBuildsOCRResultFromMockedPredictors:
|
||||
async def test_one_image_one_line_flatten_polygon(self, tmp_path: Path) -> None:
|
||||
img_path = tmp_path / "a.png"
|
||||
_write_tiny_png(img_path)
|
||||
|
||||
mock_line = _make_surya_line(
|
||||
text="hello",
|
||||
polygon=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
|
||||
)
|
||||
mock_predictor = MagicMock(
|
||||
return_value=[_make_surya_ocr_result([mock_line])]
|
||||
)
|
||||
|
||||
client = SuryaOCRClient()
|
||||
# Skip the real warm_up; inject the mock directly.
|
||||
client._recognition_predictor = mock_predictor
|
||||
client._detection_predictor = MagicMock()
|
||||
|
||||
pages = [Page(page_no=1, width=100.0, height=50.0, lines=[])]
|
||||
result = await client.ocr(
|
||||
pages,
|
||||
files=[(img_path, "image/png")],
|
||||
page_metadata=[PageMetadata(file_index=0)],
|
||||
)
|
||||
|
||||
assert len(result.result.pages) == 1
|
||||
out_page = result.result.pages[0]
|
||||
assert out_page.page_no == 1
|
||||
assert out_page.width == 100.0
|
||||
assert out_page.height == 50.0
|
||||
assert len(out_page.lines) == 1
|
||||
assert out_page.lines[0].text == "hello"
|
||||
assert out_page.lines[0].bounding_box == [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
|
||||
]
|
||||
|
||||
async def test_multiple_pages_preserves_order(self, tmp_path: Path) -> None:
|
||||
img_a = tmp_path / "a.png"
|
||||
img_b = tmp_path / "b.png"
|
||||
_write_tiny_png(img_a)
|
||||
_write_tiny_png(img_b)
|
||||
|
||||
mock_predictor = MagicMock(
|
||||
return_value=[
|
||||
_make_surya_ocr_result(
|
||||
[_make_surya_line("a-line", [[0, 0], [1, 0], [1, 1], [0, 1]])]
|
||||
),
|
||||
_make_surya_ocr_result(
|
||||
[_make_surya_line("b-line", [[0, 0], [1, 0], [1, 1], [0, 1]])]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
client = SuryaOCRClient()
|
||||
client._recognition_predictor = mock_predictor
|
||||
client._detection_predictor = MagicMock()
|
||||
|
||||
pages = [
|
||||
Page(page_no=1, width=10.0, height=20.0, lines=[]),
|
||||
Page(page_no=2, width=10.0, height=20.0, lines=[]),
|
||||
]
|
||||
result = await client.ocr(
|
||||
pages,
|
||||
files=[(img_a, "image/png"), (img_b, "image/png")],
|
||||
page_metadata=[
|
||||
PageMetadata(file_index=0),
|
||||
PageMetadata(file_index=1),
|
||||
],
|
||||
)
|
||||
|
||||
assert [p.lines[0].text for p in result.result.pages] == ["a-line", "b-line"]
|
||||
|
||||
async def test_lazy_warm_up_on_first_ocr(self, tmp_path: Path) -> None:
|
||||
img = tmp_path / "x.png"
|
||||
_write_tiny_png(img)
|
||||
|
||||
client = SuryaOCRClient()
|
||||
|
||||
# Use patch.object on the instance's warm_up so we don't need real
|
||||
# Surya module loading.
|
||||
with patch.object(client, "warm_up", autospec=True) as mocked_warm_up:
|
||||
# After warm_up is called, the predictors must be assigned.
|
||||
def fake_warm_up(self: SuryaOCRClient) -> None:
|
||||
self._recognition_predictor = MagicMock(
|
||||
return_value=[
|
||||
_make_surya_ocr_result(
|
||||
[
|
||||
_make_surya_line(
|
||||
"hi", [[0, 0], [1, 0], [1, 1], [0, 1]]
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
self._detection_predictor = MagicMock()
|
||||
|
||||
mocked_warm_up.side_effect = lambda: fake_warm_up(client)
|
||||
|
||||
pages = [Page(page_no=1, width=10.0, height=10.0, lines=[])]
|
||||
await client.ocr(
|
||||
pages,
|
||||
files=[(img, "image/png")],
|
||||
page_metadata=[PageMetadata(file_index=0)],
|
||||
)
|
||||
mocked_warm_up.assert_called_once()
|
||||
|
||||
|
||||
class TestSelfcheck:
|
||||
async def test_selfcheck_ok_with_mocked_predictors(self) -> None:
|
||||
client = SuryaOCRClient()
|
||||
client._recognition_predictor = MagicMock(
|
||||
return_value=[_make_surya_ocr_result([])]
|
||||
)
|
||||
client._detection_predictor = MagicMock()
|
||||
assert await client.selfcheck() == "ok"
|
||||
|
||||
async def test_selfcheck_fail_when_predictor_raises(self) -> None:
|
||||
client = SuryaOCRClient()
|
||||
client._recognition_predictor = MagicMock(
|
||||
side_effect=RuntimeError("cuda broken")
|
||||
)
|
||||
client._detection_predictor = MagicMock()
|
||||
assert await client.selfcheck() == "fail"
|
||||
|
||||
|
||||
def _write_tiny_png(path: Path) -> None:
|
||||
"""Write a 2x2 white PNG so PIL can open it."""
|
||||
from PIL import Image
|
||||
|
||||
Image.new("RGB", (2, 2), color="white").save(path, format="PNG")
|
||||
|
||||
|
||||
class TestGpuAvailableFlag:
|
||||
def test_default_is_none(self) -> None:
|
||||
client = SuryaOCRClient()
|
||||
assert client.gpu_available is None
|
||||
|
||||
def test_warm_up_probes_cuda_true(self) -> None:
|
||||
"""When torch reports CUDA, warm_up records True on the instance."""
|
||||
|
||||
client = SuryaOCRClient()
|
||||
fake_foundation = MagicMock()
|
||||
fake_recognition = MagicMock()
|
||||
fake_detection = MagicMock()
|
||||
fake_torch = SimpleNamespace(
|
||||
cuda=SimpleNamespace(is_available=lambda: True)
|
||||
)
|
||||
|
||||
module_patches = {
|
||||
"surya.detection": SimpleNamespace(
|
||||
DetectionPredictor=lambda: fake_detection
|
||||
),
|
||||
"surya.foundation": SimpleNamespace(
|
||||
FoundationPredictor=lambda: fake_foundation
|
||||
),
|
||||
"surya.recognition": SimpleNamespace(
|
||||
RecognitionPredictor=lambda _f: fake_recognition
|
||||
),
|
||||
"torch": fake_torch,
|
||||
}
|
||||
with patch.dict("sys.modules", module_patches):
|
||||
client.warm_up()
|
||||
|
||||
assert client.gpu_available is True
|
||||
assert client._recognition_predictor is fake_recognition
|
||||
assert client._detection_predictor is fake_detection
|
||||
|
||||
def test_warm_up_probes_cuda_false(self) -> None:
|
||||
"""CPU-mode host → warm_up records False."""
|
||||
|
||||
client = SuryaOCRClient()
|
||||
fake_torch = SimpleNamespace(
|
||||
cuda=SimpleNamespace(is_available=lambda: False)
|
||||
)
|
||||
module_patches = {
|
||||
"surya.detection": SimpleNamespace(
|
||||
DetectionPredictor=lambda: MagicMock()
|
||||
),
|
||||
"surya.foundation": SimpleNamespace(
|
||||
FoundationPredictor=lambda: MagicMock()
|
||||
),
|
||||
"surya.recognition": SimpleNamespace(
|
||||
RecognitionPredictor=lambda _f: MagicMock()
|
||||
),
|
||||
"torch": fake_torch,
|
||||
}
|
||||
with patch.dict("sys.modules", module_patches):
|
||||
client.warm_up()
|
||||
|
||||
assert client.gpu_available is False
|
||||
|
||||
def test_warm_up_is_idempotent_for_probe(self) -> None:
|
||||
"""Second warm_up short-circuits; probed flag is preserved."""
|
||||
|
||||
client = SuryaOCRClient()
|
||||
client._recognition_predictor = MagicMock()
|
||||
client._detection_predictor = MagicMock()
|
||||
client.gpu_available = True
|
||||
|
||||
# No module patches — warm_up must NOT touch sys.modules or torch.
|
||||
client.warm_up()
|
||||
assert client.gpu_available is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("unused", [None]) # keep pytest happy if file ever runs alone
|
||||
def test_module_imports(unused: None) -> None:
|
||||
assert SuryaOCRClient is not None
|
||||
|
|
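The polygon flattening asserted in `test_one_image_one_line_flatten_polygon`, turning four `[x, y]` corner points into one eight-number bounding box, reduces to a nested comprehension. The helper name below is illustrative, not the client's internal API:

```python
def flatten_polygon(polygon):
    """Flatten [[x, y], ...] corner points into [x0, y0, x1, y1, ...]."""
    return [float(coord) for point in polygon for coord in point]


flat = flatten_polygon([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
assert flat == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```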
@@ -12,7 +12,7 @@ class TestRequest:
     def test_defaults(self) -> None:
         r = Request()
         assert r.use_case_name == "Bank Statement Header"
-        assert r.default_model == "gpt-oss:20b"
+        assert r.default_model == "qwen3:14b"
         # Stable substring for agent/worker tests that want to confirm the
         # prompt is what they think it is.
         assert "extract header metadata" in r.system_prompt
tests/unit/test_use_case_inline.py (new file, 313 lines added)
@@ -0,0 +1,313 @@
"""Tests for :mod:`ix.use_cases.inline` — dynamic Pydantic class builder.
|
||||
|
||||
The builder takes an :class:`InlineUseCase` (carried on :class:`RequestIX` as
|
||||
``use_case_inline``) and produces a fresh ``(RequestClass, ResponseClass)``
|
||||
pair that the pipeline can consume in place of a registered use case.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ix.contracts.request import InlineUseCase, UseCaseFieldDef
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.use_cases.inline import build_use_case_classes
|
||||
|
||||
|
||||
class TestUseCaseFieldDef:
|
||||
def test_minimal(self) -> None:
|
||||
fd = UseCaseFieldDef(name="foo", type="str")
|
||||
assert fd.name == "foo"
|
||||
assert fd.type == "str"
|
||||
assert fd.required is False
|
||||
assert fd.description is None
|
||||
assert fd.choices is None
|
||||
|
||||
def test_extra_forbidden(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
UseCaseFieldDef.model_validate(
|
||||
{"name": "foo", "type": "str", "bogus": 1}
|
||||
)
|
||||
|
||||
def test_invalid_type_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
UseCaseFieldDef.model_validate({"name": "foo", "type": "list"})
|
||||
|
||||
|
||||
class TestInlineUseCaseRoundtrip:
|
||||
def test_json_roundtrip(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="Vendor Total",
|
||||
system_prompt="Extract invoice total and vendor.",
|
||||
default_model="qwen3:14b",
|
||||
fields=[
|
||||
UseCaseFieldDef(name="vendor", type="str", required=True),
|
||||
UseCaseFieldDef(
|
||||
name="total",
|
||||
type="decimal",
|
||||
required=True,
|
||||
description="total amount due",
|
||||
),
|
||||
UseCaseFieldDef(
|
||||
name="currency",
|
||||
type="str",
|
||||
choices=["USD", "EUR", "CHF"],
|
||||
),
|
||||
],
|
||||
)
|
||||
dumped = iuc.model_dump_json()
|
||||
round = InlineUseCase.model_validate_json(dumped)
|
||||
assert round == iuc
|
||||
# JSON is well-formed
|
||||
json.loads(dumped)
|
||||
|
||||
def test_extra_forbidden(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
InlineUseCase.model_validate(
|
||||
{
|
||||
"use_case_name": "X",
|
||||
"system_prompt": "p",
|
||||
"fields": [],
|
||||
"bogus": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestBuildBasicTypes:
|
||||
@pytest.mark.parametrize(
|
||||
"type_name, sample_value, bad_value",
|
||||
[
|
||||
("str", "hello", 123),
|
||||
("int", 42, "nope"),
|
||||
("float", 3.14, "nope"),
|
||||
("bool", True, "nope"),
|
||||
],
|
||||
)
|
||||
def test_simple_type(
|
||||
self, type_name: str, sample_value: object, bad_value: object
|
||||
) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="field", type=type_name, required=True)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
instance = resp_cls(field=sample_value)
|
||||
assert instance.field == sample_value
|
||||
with pytest.raises(ValidationError):
|
||||
resp_cls(field=bad_value)
|
||||
|
||||
def test_decimal_type(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="amount", type="decimal", required=True)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
instance = resp_cls(amount="12.34")
|
||||
assert isinstance(instance.amount, Decimal)
|
||||
assert instance.amount == Decimal("12.34")
|
||||
|
||||
def test_date_type(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="d", type="date", required=True)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
instance = resp_cls(d="2026-04-18")
|
||||
assert instance.d == date(2026, 4, 18)
|
||||
|
||||
def test_datetime_type(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="ts", type="datetime", required=True)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
instance = resp_cls(ts="2026-04-18T10:00:00")
|
||||
assert isinstance(instance.ts, datetime)
|
||||
|
||||
|
||||
class TestOptionalVsRequired:
|
||||
def test_required_field_cannot_be_missing(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="must", type="str", required=True)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
with pytest.raises(ValidationError):
|
||||
resp_cls()
|
||||
|
||||
def test_optional_field_defaults_to_none(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="maybe", type="str", required=False)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
instance = resp_cls()
|
||||
assert instance.maybe is None
|
||||
|
||||
def test_optional_field_schema_allows_null(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="maybe", type="str", required=False)],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
schema = resp_cls.model_json_schema()
|
||||
# "maybe" accepts string or null
|
||||
prop = schema["properties"]["maybe"]
|
||||
# Pydantic may express Optional as anyOf [str, null] or a type list.
|
||||
# Either is fine — just assert null is allowed somewhere.
|
||||
dumped = json.dumps(prop)
|
||||
assert "null" in dumped
|
||||
|
||||
|
||||
class TestChoices:
|
||||
def test_choices_for_str_produces_literal(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[
|
||||
UseCaseFieldDef(
|
||||
name="kind",
|
||||
type="str",
|
||||
required=True,
|
||||
choices=["a", "b", "c"],
|
||||
)
|
||||
],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
inst = resp_cls(kind="a")
|
||||
assert inst.kind == "a"
|
||||
with pytest.raises(ValidationError):
|
||||
resp_cls(kind="nope")
|
||||
schema = resp_cls.model_json_schema()
|
||||
# enum or const wind up in a referenced definition; walk the schema
|
||||
dumped = json.dumps(schema)
|
||||
assert '"a"' in dumped and '"b"' in dumped and '"c"' in dumped
|
||||
|
||||
def test_choices_for_non_str_raises_ix_001_001(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[
|
||||
UseCaseFieldDef(
|
||||
name="kind",
|
||||
type="int",
|
||||
required=True,
|
||||
choices=["1", "2"],
|
||||
)
|
||||
],
|
||||
)
|
||||
with pytest.raises(IXException) as exc:
|
||||
build_use_case_classes(iuc)
|
||||
assert exc.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
def test_empty_choices_list_ignored(self) -> None:
|
||||
# An explicitly empty list is as-if choices were unset; builder must
|
||||
# not break. If the caller sent choices=[] we treat the field as
|
||||
# plain str.
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[
|
||||
UseCaseFieldDef(
|
||||
name="kind", type="str", required=True, choices=[]
|
||||
)
|
||||
],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
inst = resp_cls(kind="anything")
|
||||
assert inst.kind == "anything"
|
||||
|
||||
|
||||
class TestValidation:
|
||||
def test_duplicate_field_names_raise(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[
|
||||
UseCaseFieldDef(name="foo", type="str"),
|
||||
UseCaseFieldDef(name="foo", type="int"),
|
||||
],
|
||||
)
|
||||
with pytest.raises(IXException) as exc:
|
||||
build_use_case_classes(iuc)
|
||||
assert exc.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
def test_invalid_field_name_raises(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="123abc", type="str")],
|
||||
)
|
||||
with pytest.raises(IXException) as exc:
|
||||
build_use_case_classes(iuc)
|
||||
assert exc.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
def test_empty_fields_list_raises(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X", system_prompt="p", fields=[]
|
||||
)
|
||||
with pytest.raises(IXException) as exc:
|
||||
build_use_case_classes(iuc)
|
||||
assert exc.value.code is IXErrorCode.IX_001_001
|
||||
|
||||
|
||||
class TestResponseClassNaming:
|
||||
def test_class_name_sanitised(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="Bank / Statement — header!",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="x", type="str")],
|
||||
)
|
||||
_req_cls, resp_cls = build_use_case_classes(iuc)
|
||||
assert resp_cls.__name__.startswith("Inline_")
|
||||
# Only alphanumerics and underscores remain.
|
||||
assert all(c.isalnum() or c == "_" for c in resp_cls.__name__)
|
||||
|
||||
def test_fresh_instances_per_call(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="X",
|
||||
system_prompt="p",
|
||||
fields=[UseCaseFieldDef(name="x", type="str")],
|
||||
)
|
||||
req1, resp1 = build_use_case_classes(iuc)
|
||||
req2, resp2 = build_use_case_classes(iuc)
|
||||
assert resp1 is not resp2
|
||||
assert req1 is not req2
|
||||
|
||||
|
||||
class TestRequestClassShape:
|
||||
def test_request_class_exposes_prompt_and_default(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="My Case",
|
||||
system_prompt="Follow directions.",
|
||||
default_model="qwen3:14b",
|
||||
fields=[UseCaseFieldDef(name="x", type="str")],
|
||||
)
|
||||
req_cls, _resp_cls = build_use_case_classes(iuc)
|
||||
inst = req_cls()
|
||||
assert inst.use_case_name == "My Case"
|
||||
assert inst.system_prompt == "Follow directions."
|
||||
assert inst.default_model == "qwen3:14b"
|
||||
assert issubclass(req_cls, BaseModel)
|
||||
|
||||
def test_default_model_none_when_unset(self) -> None:
|
||||
iuc = InlineUseCase(
|
||||
use_case_name="My Case",
|
||||
system_prompt="Follow directions.",
|
||||
fields=[UseCaseFieldDef(name="x", type="str")],
|
||||
)
|
||||
req_cls, _resp_cls = build_use_case_classes(iuc)
|
||||
inst = req_cls()
|
||||
assert inst.default_model is None
|
||||
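A stdlib-only sketch of the builder behaviour these tests pin down: identifier-checked, de-duplicated field names; required fields enforced; optional fields defaulting to None; a sanitised `Inline_*` class name; and a fresh class per call. The real builder produces Pydantic models; the names and shapes below are illustrative assumptions, not the ix implementation.

```python
def build_class(name, fields):
    """Build a fresh class from (field_name, type, required) triples."""
    if not fields:
        raise ValueError("fields must not be empty")
    seen = set()
    for fname, _ftype, _required in fields:
        if not fname.isidentifier():
            raise ValueError(f"invalid field name: {fname!r}")
        if fname in seen:
            raise ValueError(f"duplicate field name: {fname!r}")
        seen.add(fname)

    def __init__(self, **kwargs):
        for fname, ftype, required in fields:
            if fname in kwargs:
                value = kwargs[fname]
                if not isinstance(value, ftype):
                    raise TypeError(f"{fname} expects {ftype.__name__}")
                setattr(self, fname, value)
            elif required:
                raise ValueError(f"missing required field: {fname}")
            else:
                setattr(self, fname, None)  # optional fields default to None

    # Keep only alphanumerics/underscores in the class name, mirroring the
    # sanitisation the naming tests above assert.
    safe_name = "Inline_" + "".join(c if c.isalnum() else "_" for c in name)
    return type(safe_name, (), {"__init__": __init__})


Resp = build_class("Vendor Total", [("vendor", str, True), ("amount", float, False)])
r = Resp(vendor="ACME")
assert r.vendor == "ACME" and r.amount is None
```

Because `type()` is called anew on every invocation, two calls with the same definition yield distinct class objects, matching the fresh-instances-per-call expectation.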
Some files were not shown because too many files have changed in this diff.