"""REST routes (spec §5). The routes depend on two injected objects: * a session factory (``get_session_factory_dep``): swapped in tests so we can use the fixture's per-test engine instead of the lazy process-wide one in ``ix.store.engine``. * a :class:`Probes` bundle (``get_probes``): each probe returns the per-subsystem state string used by ``/healthz``. Tests inject canned probes; Chunk 4 wires the real Ollama/Surya ones. ``/healthz`` has a strict 2-second postgres timeout — we use an ``asyncio.wait_for`` around a ``SELECT 1`` roundtrip so a broken pool or a hung connection can't wedge the healthcheck endpoint. """ from __future__ import annotations import asyncio from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from typing import Annotated, Literal from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Response from sqlalchemy import func, select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from ix.adapters.rest.schemas import HealthStatus, JobSubmitResponse, MetricsResponse from ix.contracts.job import Job from ix.contracts.request import RequestIX from ix.store import jobs_repo from ix.store.engine import get_session_factory from ix.store.models import IxJob @dataclass class Probes: """Injected subsystem-probe callables for ``/healthz``. Each callable returns the literal status string expected in the body. Probes are sync by design — none of the real ones need awaits today and keeping them sync lets tests pass plain lambdas. Real probes that need async work run the call through ``asyncio.run_in_executor`` inside the callable (Chunk 4). """ ollama: Callable[[], Literal["ok", "degraded", "fail"]] ocr: Callable[[], Literal["ok", "fail"]] def get_session_factory_dep() -> async_sessionmaker[AsyncSession]: """Default DI: the process-wide store factory. Tests override this.""" return get_session_factory() def get_probes() -> Probes: """Default DI: a pair of ``fail`` probes. Production wiring (Chunk 4) overrides this factory with real Ollama + Surya probes at app-startup time. Integration tests override via ``app.dependency_overrides[get_probes]`` with a canned ``ok`` pair. The default here ensures a mis-wired deployment surfaces clearly in ``/healthz`` rather than claiming everything is fine by accident. """ return Probes(ollama=lambda: "fail", ocr=lambda: "fail") router = APIRouter() @router.post("/jobs", response_model=JobSubmitResponse, status_code=201) async def submit_job( request: RequestIX, response: Response, session_factory: Annotated[ async_sessionmaker[AsyncSession], Depends(get_session_factory_dep) ], ) -> JobSubmitResponse: """Submit a new job. Per spec §5: 201 on first insert, 200 on idempotent re-submit of an existing ``(client_id, request_id)`` pair. We detect the second case by snapshotting the pre-insert row set and comparing ``created_at``. """ async with session_factory() as session: existing = await jobs_repo.get_by_correlation( session, request.ix_client_id, request.request_id ) job = await jobs_repo.insert_pending( session, request, callback_url=request.callback_url ) await session.commit() if existing is not None: # Idempotent re-submit — flip to 200. FastAPI's declared status_code # is 201, but setting response.status_code overrides it per-call. response.status_code = 200 return JobSubmitResponse(job_id=job.job_id, ix_id=job.ix_id, status=job.status) @router.get("/jobs/{job_id}", response_model=Job) async def get_job( job_id: UUID, session_factory: Annotated[ async_sessionmaker[AsyncSession], Depends(get_session_factory_dep) ], ) -> Job: async with session_factory() as session: job = await jobs_repo.get(session, job_id) if job is None: raise HTTPException(status_code=404, detail="job not found") return job @router.get("/jobs", response_model=Job) async def lookup_job_by_correlation( client_id: Annotated[str, Query(...)], request_id: Annotated[str, Query(...)], session_factory: Annotated[ async_sessionmaker[AsyncSession], Depends(get_session_factory_dep) ], ) -> Job: async with session_factory() as session: job = await jobs_repo.get_by_correlation(session, client_id, request_id) if job is None: raise HTTPException(status_code=404, detail="job not found") return job @router.get("/healthz") async def healthz( response: Response, session_factory: Annotated[ async_sessionmaker[AsyncSession], Depends(get_session_factory_dep) ], probes: Annotated[Probes, Depends(get_probes)], ) -> HealthStatus: """Per spec §5: postgres / ollama / ocr; 200 iff all three == ok.""" postgres_state: Literal["ok", "fail"] = "fail" try: async def _probe() -> None: async with session_factory() as session: await session.execute(text("SELECT 1")) await asyncio.wait_for(_probe(), timeout=2.0) postgres_state = "ok" except Exception: postgres_state = "fail" try: ollama_state = probes.ollama() except Exception: ollama_state = "fail" try: ocr_state = probes.ocr() except Exception: ocr_state = "fail" body = HealthStatus( postgres=postgres_state, ollama=ollama_state, ocr=ocr_state ) if postgres_state != "ok" or ollama_state != "ok" or ocr_state != "ok": response.status_code = 503 return body @router.get("/metrics", response_model=MetricsResponse) async def metrics( session_factory: Annotated[ async_sessionmaker[AsyncSession], Depends(get_session_factory_dep) ], ) -> MetricsResponse: """Counters over the last 24h — plain JSON per spec §5.""" since = datetime.now(UTC) - timedelta(hours=24) async with session_factory() as session: pending = await session.scalar( select(func.count()).select_from(IxJob).where(IxJob.status == "pending") ) running = await session.scalar( select(func.count()).select_from(IxJob).where(IxJob.status == "running") ) done_24h = await session.scalar( select(func.count()) .select_from(IxJob) .where(IxJob.status == "done", IxJob.finished_at >= since) ) error_24h = await session.scalar( select(func.count()) .select_from(IxJob) .where(IxJob.status == "error", IxJob.finished_at >= since) ) # Per-use-case average seconds. ``request`` is JSONB, so we dig out # the use_case key via ->>. Only consider rows that both started and # finished in the window (can't compute elapsed otherwise). rows = ( await session.execute( text( "SELECT request->>'use_case' AS use_case, " "AVG(EXTRACT(EPOCH FROM (finished_at - started_at))) " "FROM ix_jobs " "WHERE status='done' AND finished_at IS NOT NULL " "AND started_at IS NOT NULL AND finished_at >= :since " "GROUP BY request->>'use_case'" ), {"since": since}, ) ).all() by_use_case = {row[0]: float(row[1]) for row in rows if row[0] is not None} return MetricsResponse( jobs_pending=int(pending or 0), jobs_running=int(running or 0), jobs_done_24h=int(done_24h or 0), jobs_error_24h=int(error_24h or 0), by_use_case_seconds=by_use_case, )