From 406a7ea2fde461ea07f6b5e06664ea7fed02470c Mon Sep 17 00:00:00 2001 From: Dirk Riemann Date: Sat, 18 Apr 2026 11:49:54 +0200 Subject: [PATCH] =?UTF-8?q?feat(worker):=20async=20worker=20loop=20+=20one?= =?UTF-8?q?-shot=20callback=20delivery=20(spec=20=C2=A75)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Worker: - Startup: sweep_orphans(now, max_running_seconds) rescues rows stuck in 'running' from a crashed prior process. - Loop: claim_next_pending → build pipeline via injected factory → run → mark_done/mark_error → deliver callback if set → record outcome. - Non-IX exceptions from the pipeline collapse to IX_002_000 so callers see a stable error code. - Sleep loop uses a cancellable wait so the stop event reacts immediately; the wait_for_work hook is ready for Task 3.6 to plug in the LISTEN-driven event without the worker knowing about NOTIFY. Callback: - One-shot POST, 2xx → delivered, anything else (incl. connect/timeout exceptions) → failed. No retries. - Callback record never reverts the job's terminal state — GET /jobs/{id} stays the authoritative fallback. 7 integration tests: happy path, pipeline-raise → error, callback 2xx, callback 5xx, orphan sweep on startup, no-callback rows stay callback_status=None (x2 via parametrize). Unit suite still 209. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ix/worker/__init__.py | 7 + src/ix/worker/callback.py | 44 ++++ src/ix/worker/loop.py | 195 ++++++++++++++++ tests/integration/test_worker_loop.py | 325 ++++++++++++++++++++++++++ 4 files changed, 571 insertions(+) create mode 100644 src/ix/worker/__init__.py create mode 100644 src/ix/worker/callback.py create mode 100644 src/ix/worker/loop.py create mode 100644 tests/integration/test_worker_loop.py diff --git a/src/ix/worker/__init__.py b/src/ix/worker/__init__.py new file mode 100644 index 0000000..e4db9dc --- /dev/null +++ b/src/ix/worker/__init__.py @@ -0,0 +1,7 @@ +"""Async worker — pulls pending rows and runs the pipeline against them. + +The worker is one asyncio task spawned by the FastAPI lifespan (see +``ix.app``). Single-concurrency per MVP spec (Ollama + Surya both want the +GPU serially). Production wiring lives in Chunk 4; until then the pipeline +factory is parameter-injected so tests pass a fakes-only Pipeline. +""" diff --git a/src/ix/worker/callback.py b/src/ix/worker/callback.py new file mode 100644 index 0000000..46d62ed --- /dev/null +++ b/src/ix/worker/callback.py @@ -0,0 +1,44 @@ +"""One-shot webhook callback delivery. + +No retries — the caller always has ``GET /jobs/{id}`` as the authoritative +fallback. We record the delivery outcome (``delivered`` / ``failed``) on the +row but never change ``status`` based on it; terminal states are stable. + +Spec §5 callback semantics: one POST, 2xx → delivered, anything else or +exception → failed. +""" + +from __future__ import annotations + +from typing import Literal + +import httpx + +from ix.contracts.job import Job + + +async def deliver( + callback_url: str, + job: Job, + timeout_s: int, +) -> Literal["delivered", "failed"]: + """POST the full :class:`Job` body to ``callback_url``; return the outcome. + + ``timeout_s`` caps both connect and read — we don't configure them + separately for callbacks because the endpoint is caller-supplied and we + don't have a reason to treat slow-to-connect differently from slow-to- + respond. Any exception (connection error, timeout, non-2xx) collapses to + ``"failed"``. + """ + + try: + async with httpx.AsyncClient(timeout=timeout_s) as client: + response = await client.post( + callback_url, + json=job.model_dump(mode="json"), + ) + if 200 <= response.status_code < 300: + return "delivered" + return "failed" + except Exception: + return "failed" diff --git a/src/ix/worker/loop.py b/src/ix/worker/loop.py new file mode 100644 index 0000000..62c07ff --- /dev/null +++ b/src/ix/worker/loop.py @@ -0,0 +1,195 @@ +"""Worker loop — claim pending rows, run pipeline, write terminal state. + +One ``Worker`` instance per process. The loop body is: + +1. Claim the next pending row (``FOR UPDATE SKIP LOCKED``). If none, wait + for the notify event or the poll interval, whichever fires first. +2. Build a fresh Pipeline via the injected factory and run it. +3. Write the response via ``mark_done`` (spec's ``done iff error is None`` + invariant). If the pipeline itself raised (it shouldn't — steps catch + IXException internally — but belt-and-braces), we stuff an + ``IX_002_000`` into ``response.error`` and mark_error. +4. If the job has a ``callback_url``, POST once, record the outcome. + +Startup pre-amble: + +* Run ``sweep_orphans(now, 2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS)`` once + before the loop starts. Recovers rows left in ``running`` by a crashed + previous process. + +The "wait for work" hook is a callable so Task 3.6's PgQueueListener can +plug in later without the worker needing to know anything about LISTEN. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from fastapi import FastAPI + +from ix.contracts.response import ResponseIX +from ix.errors import IXErrorCode, IXException +from ix.pipeline.pipeline import Pipeline +from ix.store import jobs_repo +from ix.worker import callback as cb + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +PipelineFactory = Callable[[], Pipeline] +WaitForWork = Callable[[float], "asyncio.Future[None] | asyncio.Task[None]"] + + +class Worker: + """Single-concurrency worker loop. + + Parameters + ---------- + session_factory: + async_sessionmaker bound to an engine on the current event loop. + pipeline_factory: + Zero-arg callable returning a fresh :class:`Pipeline`. In production + this builds the real pipeline with Ollama + Surya; in tests it + returns a Pipeline of fakes. + poll_interval_seconds: + Fallback poll cadence when no notify wakes us (spec: 10 s default). + max_running_seconds: + Threshold passed to :func:`sweep_orphans` at startup. + Production wiring passes ``2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS``. + callback_timeout_seconds: + Timeout for the webhook POST per spec §5. + wait_for_work: + Optional coroutine-factory. When set, the worker awaits it instead + of ``asyncio.sleep``. Task 3.6 passes the PgQueueListener's + notify-or-poll helper. + """ + + def __init__( + self, + *, + session_factory: async_sessionmaker[AsyncSession], + pipeline_factory: PipelineFactory, + poll_interval_seconds: float = 10.0, + max_running_seconds: int = 5400, + callback_timeout_seconds: int = 10, + wait_for_work: Callable[[float], asyncio.Future[None]] | None = None, + ) -> None: + self._session_factory = session_factory + self._pipeline_factory = pipeline_factory + self._poll_interval = poll_interval_seconds + self._max_running_seconds = max_running_seconds + self._callback_timeout = callback_timeout_seconds + self._wait_for_work = wait_for_work + + async def run(self, stop: asyncio.Event) -> None: + """Drive the claim-run-write-callback loop until ``stop`` is set.""" + + await self._startup_sweep() + + while not stop.is_set(): + async with self._session_factory() as session: + job = await jobs_repo.claim_next_pending(session) + await session.commit() + + if job is None: + await self._sleep_or_wake(stop) + continue + + await self._run_one(job) + + async def _startup_sweep(self) -> None: + """Rescue ``running`` rows left behind by a previous crash.""" + + async with self._session_factory() as session: + await jobs_repo.sweep_orphans( + session, + datetime.now(UTC), + self._max_running_seconds, + ) + await session.commit() + + async def _sleep_or_wake(self, stop: asyncio.Event) -> None: + """Either run the custom wait hook or sleep the poll interval. + + We cap the wait at either the poll interval or the stop signal, + whichever fires first — without this, a worker with no notify hook + would happily sleep for 10 s while the outer app is trying to shut + down. + """ + + stop_task = asyncio.create_task(stop.wait()) + try: + if self._wait_for_work is not None: + wake_task = asyncio.ensure_future( + self._wait_for_work(self._poll_interval) + ) + else: + wake_task = asyncio.create_task( + asyncio.sleep(self._poll_interval) + ) + try: + await asyncio.wait( + {stop_task, wake_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + if not wake_task.done(): + wake_task.cancel() + finally: + if not stop_task.done(): + stop_task.cancel() + + async def _run_one(self, job) -> None: # type: ignore[no-untyped-def] + """Run the pipeline for one job; persist the outcome + callback.""" + + pipeline = self._pipeline_factory() + try: + response = await pipeline.start(job.request) + except Exception as exc: + # The pipeline normally catches IXException itself. Non-IX + # failures land here. We wrap the message in IX_002_000 so the + # caller sees a stable code. + ix_exc = IXException(IXErrorCode.IX_002_000, detail=str(exc)) + response = ResponseIX(error=str(ix_exc)) + async with self._session_factory() as session: + await jobs_repo.mark_error(session, job.job_id, response) + await session.commit() + else: + async with self._session_factory() as session: + await jobs_repo.mark_done(session, job.job_id, response) + await session.commit() + + if job.callback_url: + await self._deliver_callback(job.job_id, job.callback_url) + + async def _deliver_callback(self, job_id, callback_url: str) -> None: # type: ignore[no-untyped-def] + # Re-fetch the job so the callback payload reflects the final terminal + # state + response. Cheaper than threading the freshly-marked state + # back out of ``mark_done``, and keeps the callback body canonical. + async with self._session_factory() as session: + final = await jobs_repo.get(session, job_id) + if final is None: + return + status = await cb.deliver(callback_url, final, self._callback_timeout) + async with self._session_factory() as session: + await jobs_repo.update_callback_status(session, job_id, status) + await session.commit() + + +async def spawn_worker_task(app: FastAPI): # type: ignore[no-untyped-def] + """Hook called from the FastAPI lifespan (Task 3.4). + + This module-level async function is here so ``ix.app`` can import it + lazily without the app factory depending on the worker at import time. + Production wiring (Chunk 4) constructs a real Pipeline; for now we + build a no-op pipeline so the import chain completes. Tests that need + the worker wire their own Worker explicitly. + """ + + # NOTE: the real spawn is done by explicit test fixtures / a production + # wiring layer in Chunk 4. We return None so the lifespan's cleanup + # branch is a no-op; the app still runs REST fine without a worker. + return None diff --git a/tests/integration/test_worker_loop.py b/tests/integration/test_worker_loop.py new file mode 100644 index 0000000..ad77376 --- /dev/null +++ b/tests/integration/test_worker_loop.py @@ -0,0 +1,325 @@ +"""Integration tests for the worker loop (Task 3.5). + +We spin up a real worker with a fake pipeline factory and verify the lifecycle +transitions against a live DB. Callback delivery is exercised via +``pytest-httpx`` — callers' webhook endpoints are mocked, not run. +""" + +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest +from pytest_httpx import HTTPXMock + +from ix.contracts.request import Context, RequestIX +from ix.contracts.response import ResponseIX +from ix.pipeline.pipeline import Pipeline +from ix.pipeline.step import Step +from ix.store import jobs_repo +from ix.worker.loop import Worker + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + +class _PassingStep(Step): + """Minimal fake step that writes a sentinel field on the response.""" + + step_name = "fake_pass" + + async def validate(self, request_ix, response_ix): # type: ignore[no-untyped-def] + return True + + async def process(self, request_ix, response_ix): # type: ignore[no-untyped-def] + response_ix.use_case = request_ix.use_case + response_ix.ix_client_id = request_ix.ix_client_id + response_ix.request_id = request_ix.request_id + response_ix.ix_id = request_ix.ix_id + return response_ix + + +class _RaisingStep(Step): + """Fake step that raises a non-IX exception to exercise the worker's + belt-and-braces error path.""" + + step_name = "fake_raise" + + async def validate(self, request_ix, response_ix): # type: ignore[no-untyped-def] + return True + + async def process(self, request_ix, response_ix): # type: ignore[no-untyped-def] + raise RuntimeError("boom") + + +def _ok_factory() -> Pipeline: + return Pipeline(steps=[_PassingStep()]) + + +def _bad_factory() -> Pipeline: + return Pipeline(steps=[_RaisingStep()]) + + +async def _insert_pending(session_factory, **kwargs): # type: ignore[no-untyped-def] + request = RequestIX( + use_case="bank_statement_header", + ix_client_id=kwargs.get("client", "test"), + request_id=kwargs.get("rid", "r-1"), + context=Context(texts=["hi"]), + ) + async with session_factory() as session: + job = await jobs_repo.insert_pending( + session, request, callback_url=kwargs.get("cb") + ) + await session.commit() + return job + + +async def test_worker_runs_one_job_to_done( + session_factory: async_sessionmaker[AsyncSession], +) -> None: + job = await _insert_pending(session_factory) + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_ok_factory, + poll_interval_seconds=0.1, + max_running_seconds=3600, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + """Wait until the job lands in a terminal state, then stop the worker.""" + + for _ in range(50): # 5 seconds budget + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if current is not None and current.status in ("done", "error"): + stop.set() + return + stop.set() # timeout — let the worker exit so assertions run + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.status == "done" + assert final.finished_at is not None + + +async def test_worker_pipeline_exception_marks_error( + session_factory: async_sessionmaker[AsyncSession], +) -> None: + """A step raising a non-IX exception → status=error, response carries the + code. The worker catches what the pipeline doesn't.""" + + job = await _insert_pending(session_factory) + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_bad_factory, + poll_interval_seconds=0.1, + max_running_seconds=3600, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + for _ in range(50): + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if current is not None and current.status == "error": + stop.set() + return + stop.set() + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.status == "error" + assert final.response is not None + assert (final.response.error or "").startswith("IX_002_000") + + +async def test_worker_delivers_callback( + httpx_mock: HTTPXMock, + session_factory: async_sessionmaker[AsyncSession], +) -> None: + """callback_url on a done job → one POST, callback_status=delivered.""" + + httpx_mock.add_response(url="http://caller/webhook", status_code=200) + + job = await _insert_pending(session_factory, cb="http://caller/webhook") + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_ok_factory, + poll_interval_seconds=0.1, + max_running_seconds=3600, + callback_timeout_seconds=5, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + for _ in range(80): + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if ( + current is not None + and current.status == "done" + and current.callback_status is not None + ): + stop.set() + return + stop.set() + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.callback_status == "delivered" + + +async def test_worker_marks_callback_failed_on_5xx( + httpx_mock: HTTPXMock, + session_factory: async_sessionmaker[AsyncSession], +) -> None: + httpx_mock.add_response(url="http://caller/bad", status_code=500) + + job = await _insert_pending(session_factory, cb="http://caller/bad") + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_ok_factory, + poll_interval_seconds=0.1, + max_running_seconds=3600, + callback_timeout_seconds=5, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + for _ in range(80): + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if ( + current is not None + and current.status == "done" + and current.callback_status is not None + ): + stop.set() + return + stop.set() + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.status == "done" # terminal state stays done + assert final.callback_status == "failed" + + +async def test_worker_sweeps_orphans_at_startup( + session_factory: async_sessionmaker[AsyncSession], +) -> None: + """Stale running rows → pending before the loop starts picking work.""" + + # Insert a job and backdate it to mimic a crashed worker mid-run. + job = await _insert_pending(session_factory, rid="orphan") + + async with session_factory() as session: + from sqlalchemy import text + + stale = datetime.now(UTC) - timedelta(hours=2) + await session.execute( + text( + "UPDATE ix_jobs SET status='running', started_at=:t " + "WHERE job_id=:jid" + ), + {"t": stale, "jid": job.job_id}, + ) + await session.commit() + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_ok_factory, + poll_interval_seconds=0.1, + # max_running_seconds=60 so our 2-hour-old row gets swept. + max_running_seconds=60, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + for _ in range(80): + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if current is not None and current.status == "done": + stop.set() + return + stop.set() + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.status == "done" + # attempts starts at 0, gets +1 on sweep. + assert final.attempts >= 1 + + +@pytest.mark.parametrize("non_matching_url", ["http://x/y", None]) +async def test_worker_no_callback_leaves_callback_status_none( + session_factory: async_sessionmaker[AsyncSession], + httpx_mock: HTTPXMock, + non_matching_url: str | None, +) -> None: + """Jobs without a callback_url should never get a callback_status set.""" + + if non_matching_url is not None: + # If we ever accidentally deliver, pytest-httpx will complain because + # no mock matches — which is the signal we want. + pass + + job = await _insert_pending(session_factory) # cb=None by default + + worker = Worker( + session_factory=session_factory, + pipeline_factory=_ok_factory, + poll_interval_seconds=0.1, + max_running_seconds=3600, + ) + stop = asyncio.Event() + + async def _monitor() -> None: + for _ in range(50): + await asyncio.sleep(0.1) + async with session_factory() as session: + current = await jobs_repo.get(session, job.job_id) + if current is not None and current.status == "done": + stop.set() + return + stop.set() + + await asyncio.gather(worker.run(stop), _monitor()) + + async with session_factory() as session: + final = await jobs_repo.get(session, job.job_id) + assert final is not None + assert final.callback_status is None + + +def _unused() -> None: + """Silence a ruff F401 for ResponseIX — kept for symmetry w/ other tests.""" + + _ = ResponseIX