"""Worker loop — claim pending rows, run pipeline, write terminal state.

One ``Worker`` instance per process. The loop body is:

1. Claim the next pending row (``FOR UPDATE SKIP LOCKED``). If none, wait
   for the notify event or the poll interval, whichever fires first.
2. Build a fresh Pipeline via the injected factory and run it.
3. Write the response via ``mark_done`` (spec's ``done iff error is None``
   invariant). If the pipeline itself raised (it shouldn't — steps catch
   IXException internally — but belt-and-braces), we stuff an
   ``IX_002_000`` into ``response.error`` and mark_error.
4. If the job has a ``callback_url``, POST once, record the outcome.

Startup pre-amble:

* Run ``sweep_orphans(now, 2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS)`` once
  before the loop starts. Recovers rows left in ``running`` by a crashed
  previous process.

The "wait for work" hook is a callable so Task 3.6's PgQueueListener can
plug in later without the worker needing to know anything about LISTEN.
"""

from __future__ import annotations

import asyncio
import logging
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING

from ix.contracts.response import ResponseIX
from ix.errors import IXErrorCode, IXException
from ix.pipeline.pipeline import Pipeline
from ix.store import jobs_repo
from ix.worker import callback as cb

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

logger = logging.getLogger(__name__)

PipelineFactory = Callable[[], Pipeline]
WaitForWork = Callable[[float], "asyncio.Future[None] | asyncio.Task[None]"]


class Worker:
    """Single-concurrency worker loop.

    Parameters
    ----------
    session_factory:
        async_sessionmaker bound to an engine on the current event loop.
    pipeline_factory:
        Zero-arg callable returning a fresh :class:`Pipeline`. In production
        this builds the real pipeline with Ollama + Surya; in tests it
        returns a Pipeline of fakes.
    poll_interval_seconds:
        Fallback poll cadence when no notify wakes us (spec: 10 s default).
    max_running_seconds:
        Threshold passed to :func:`sweep_orphans` at startup. Production
        wiring passes ``2 * IX_PIPELINE_REQUEST_TIMEOUT_SECONDS``.
    callback_timeout_seconds:
        Timeout for the webhook POST per spec §5.
    wait_for_work:
        Optional coroutine-factory. When set, the worker calls it with the
        poll interval and awaits the result instead of ``asyncio.sleep``.
        Task 3.6 passes the PgQueueListener's notify-or-poll helper. If the
        returned awaitable fails, the failure is logged and the worker
        falls back to waiting out the poll interval.
    """

    def __init__(
        self,
        *,
        session_factory: async_sessionmaker[AsyncSession],
        pipeline_factory: PipelineFactory,
        poll_interval_seconds: float = 10.0,
        max_running_seconds: int = 5400,
        callback_timeout_seconds: int = 10,
        wait_for_work: WaitForWork | None = None,
    ) -> None:
        self._session_factory = session_factory
        self._pipeline_factory = pipeline_factory
        self._poll_interval = poll_interval_seconds
        self._max_running_seconds = max_running_seconds
        self._callback_timeout = callback_timeout_seconds
        self._wait_for_work = wait_for_work

    async def run(self, stop: asyncio.Event) -> None:
        """Drive the claim-run-write-callback loop until ``stop`` is set."""
        await self._startup_sweep()

        while not stop.is_set():
            # Claim in a short-lived session and commit right away; the
            # session is closed again before the pipeline runs.
            async with self._session_factory() as session:
                job = await jobs_repo.claim_next_pending(session)
                await session.commit()

            if job is None:
                await self._sleep_or_wake(stop)
                continue

            await self._run_one(job)

    async def _startup_sweep(self) -> None:
        """Rescue ``running`` rows left behind by a previous crash."""
        async with self._session_factory() as session:
            await jobs_repo.sweep_orphans(
                session,
                datetime.now(UTC),
                self._max_running_seconds,
            )
            await session.commit()

    async def _sleep_or_wake(self, stop: asyncio.Event) -> None:
        """Either run the custom wait hook or sleep the poll interval.

        We cap the wait at either the poll interval or the stop signal,
        whichever fires first — without this, a worker with no notify hook
        would happily sleep for 10 s while the outer app is trying to shut
        down.

        If the wait hook's awaitable finishes with an exception, the
        exception is retrieved and logged, and we fall back to waiting out
        the poll interval (still interruptible by ``stop``). Without that
        fallback a broken hook would return instantly on every call and the
        outer loop would re-run the claim query with no delay at all.
        """
        stop_task = asyncio.create_task(stop.wait())
        try:
            if self._wait_for_work is not None:
                wake_task = asyncio.ensure_future(
                    self._wait_for_work(self._poll_interval)
                )
            else:
                wake_task = asyncio.create_task(
                    asyncio.sleep(self._poll_interval)
                )
            try:
                await asyncio.wait(
                    {stop_task, wake_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                if not wake_task.done():
                    wake_task.cancel()

            # ``exception()`` may only be called on a finished, non-cancelled
            # future; calling it also marks the exception as retrieved.
            hook_error = (
                wake_task.exception()
                if wake_task.done() and not wake_task.cancelled()
                else None
            )
            if hook_error is not None:
                logger.warning(
                    "wait_for_work hook failed; falling back to poll interval",
                    exc_info=hook_error,
                )
                # ``asyncio.wait`` with a timeout neither raises nor cancels:
                # it returns once ``stop`` fires or the interval elapses.
                await asyncio.wait({stop_task}, timeout=self._poll_interval)
        finally:
            if not stop_task.done():
                stop_task.cancel()

    async def _run_one(self, job) -> None:  # type: ignore[no-untyped-def]
        """Run the pipeline for one job; persist the outcome + callback.

        A pipeline that returns normally is recorded via ``mark_done``; one
        that raises is recorded via ``mark_error`` with an ``IX_002_000``
        response. Afterwards, if the job carries a ``callback_url``, the
        callback is delivered.
        """
        pipeline = self._pipeline_factory()
        try:
            response = await pipeline.start(job.request)
        except Exception as exc:
            # The pipeline normally catches IXException itself. Non-IX
            # failures land here. We wrap the message in IX_002_000 so the
            # caller sees a stable code.
            # Only the message survives into the response, so keep the
            # traceback in the log.
            logger.exception("Pipeline raised for job %s", job.job_id)
            ix_exc = IXException(IXErrorCode.IX_002_000, detail=str(exc))
            response = ResponseIX(error=str(ix_exc))
            async with self._session_factory() as session:
                await jobs_repo.mark_error(session, job.job_id, response)
                await session.commit()
        else:
            async with self._session_factory() as session:
                await jobs_repo.mark_done(session, job.job_id, response)
                await session.commit()

        if job.callback_url:
            await self._deliver_callback(job.job_id, job.callback_url)

    async def _deliver_callback(self, job_id, callback_url: str) -> None:  # type: ignore[no-untyped-def]
        """Deliver the callback for ``job_id`` and store the delivery status.

        Returns without delivering anything when the job can no longer be
        fetched.
        """
        # Re-fetch the job so the callback payload reflects the final terminal
        # state + response. Cheaper than threading the freshly-marked state
        # back out of ``mark_done``, and keeps the callback body canonical.
        async with self._session_factory() as session:
            final = await jobs_repo.get(session, job_id)
        if final is None:
            return

        status = await cb.deliver(callback_url, final, self._callback_timeout)

        async with self._session_factory() as session:
            await jobs_repo.update_callback_status(session, job_id, status)
            await session.commit()