diff --git a/src/ix/adapters/pg_queue/__init__.py b/src/ix/adapters/pg_queue/__init__.py
new file mode 100644
index 0000000..562a743
--- /dev/null
+++ b/src/ix/adapters/pg_queue/__init__.py
@@ -0,0 +1,15 @@
+"""Postgres queue adapter — ``LISTEN ix_jobs_new`` + 10 s fallback poll.
+
+This is a secondary transport: a direct-SQL writer can insert a row and
+``NOTIFY ix_jobs_new, ''`` and the worker wakes up within the roundtrip
+time rather than the 10 s fallback poll. The REST adapter doesn't need the
+listener because the worker is already running in-process; this exists for
+external callers who bypass the REST API.
+"""
+
+from ix.adapters.pg_queue.listener import (
+    PgQueueListener,
+    asyncpg_dsn_from_sqlalchemy_url,
+)
+
+__all__ = ["PgQueueListener", "asyncpg_dsn_from_sqlalchemy_url"]
diff --git a/src/ix/adapters/pg_queue/listener.py b/src/ix/adapters/pg_queue/listener.py
new file mode 100644
index 0000000..d89a51a
--- /dev/null
+++ b/src/ix/adapters/pg_queue/listener.py
@@ -0,0 +1,141 @@
+"""Dedicated asyncpg connection that LISTENs to ``ix_jobs_new``.
+
+We hold the connection *outside* the SQLAlchemy pool because SQLAlchemy's
+asyncpg dialect doesn't expose the raw connection in a way that survives
+the pool's checkout/checkin dance, and LISTEN needs a connection that
+stays open for the full session to receive asynchronous notifications.
+
+The adapter contract the worker sees is a single coroutine-factory,
+``wait_for_work(poll_seconds)``, which completes either when a NOTIFY
+arrives or when ``poll_seconds`` elapse. The worker doesn't care which
+woke it — it just goes back to its claim query.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from urllib.parse import quote, unquote, urlparse
+
+import asyncpg
+
+# RFC 3986 sub-delims: legal unescaped inside userinfo, so keep them readable.
+_USERINFO_SAFE = "!$&'()*+,;="
+
+
+def _requote_userinfo(value: str | None) -> str:
+    """Percent-decode *value*, then re-encode what would break URL parsing."""
+
+    return quote(unquote(value), safe=_USERINFO_SAFE) if value else ""
+
+
+def asyncpg_dsn_from_sqlalchemy_url(url: str) -> str:
+    """Strip the SQLAlchemy ``postgresql+asyncpg://`` prefix for raw asyncpg.
+
+    asyncpg's connect() expects the plain ``postgres://user:pass@host/db``
+    shape; the ``+driver`` segment SQLAlchemy adds confuses it. Credentials
+    are normalised: percent-decoded, then re-encoded only where a character
+    would change how the URL parses (``@``, ``:``, ``/``, ``?``, ``#``,
+    ``%``). So ``p%21`` still comes out as ``p!``, while ``p%40ss`` stays
+    encoded instead of putting a second ``@`` into the authority. IPv6
+    hosts get their brackets back because ``urlparse().hostname`` strips
+    them. Any query string or fragment is dropped.
+    """
+
+    parsed = urlparse(url)
+    scheme = parsed.scheme.split("+", 1)[0]
+    user = _requote_userinfo(parsed.username)
+    password = _requote_userinfo(parsed.password)
+    auth = ""
+    if user:
+        auth = user
+        if password:
+            auth += f":{password}"
+        auth += "@"
+    host = parsed.hostname or ""
+    if ":" in host:
+        # Bare IPv6 literal — ``host:port`` is ambiguous without brackets.
+        host = f"[{host}]"
+    netloc = host
+    if parsed.port:
+        netloc += f":{parsed.port}"
+    return f"{scheme}://{auth}{netloc}{parsed.path}"
+
+
+class PgQueueListener:
+    """Long-lived asyncpg connection that sets an event on each NOTIFY.
+
+    The worker uses :meth:`wait_for_work` as its ``wait_for_work`` hook:
+    one call resolves when either a NOTIFY is received OR ``timeout``
+    seconds elapse, whichever comes first. The event is cleared after each
+    resolution so subsequent waits don't see stale state.
+    """
+
+    def __init__(self, dsn: str, channel: str = "ix_jobs_new") -> None:
+        self._dsn = dsn
+        self._channel = channel
+        self._conn: asyncpg.Connection | None = None
+        self._event = asyncio.Event()
+        # Protect add_listener / remove_listener against concurrent
+        # start/stop — shouldn't happen in practice but a stray double-stop
+        # from a lifespan shutdown shouldn't raise ``listener not found``.
+        self._lock = asyncio.Lock()
+
+    async def start(self) -> None:
+        """Open the dedicated connection and subscribe; no-op if running."""
+
+        async with self._lock:
+            if self._conn is not None:
+                return
+            conn = await asyncpg.connect(self._dsn)
+            try:
+                await conn.add_listener(self._channel, self._on_notify)
+            except BaseException:
+                # Don't leak the socket, and don't publish a connection
+                # without a listener — later start() calls would no-op.
+                await conn.close()
+                raise
+            self._conn = conn
+
+    async def stop(self) -> None:
+        """Unsubscribe and close the connection; no-op if not running."""
+
+        async with self._lock:
+            # Detach first so a failing close() can't leave a dead
+            # connection behind for the next start()/stop() to trip over.
+            conn, self._conn = self._conn, None
+            if conn is None:
+                return
+            try:
+                await conn.remove_listener(self._channel, self._on_notify)
+            finally:
+                await conn.close()
+
+    def _on_notify(
+        self,
+        connection: asyncpg.Connection,
+        pid: int,
+        channel: str,
+        payload: str,
+    ) -> None:
+        """asyncpg listener callback — signals the waiter."""
+
+        # We don't care about payload/pid/channel — any NOTIFY on our
+        # channel means "go check for pending rows". Keep the body tiny so
+        # asyncpg's single dispatch loop stays snappy.
+        self._event.set()
+
+    async def wait_for_work(self, timeout: float) -> None:
+        """Resolve when a NOTIFY arrives or ``timeout`` seconds pass.
+
+        We wait on the event with a timeout. ``asyncio.wait_for`` raises
+        :class:`asyncio.TimeoutError` on expiry; we swallow it because the
+        worker treats "either signal" identically. The event is cleared
+        after every wait so the next call starts fresh.
+        """
+
+        try:
+            await asyncio.wait_for(self._event.wait(), timeout=timeout)
+        except TimeoutError:
+            pass
+        finally:
+            self._event.clear()
diff --git a/src/ix/app.py b/src/ix/app.py
index 6517d4a..ab7a3d6 100644
--- a/src/ix/app.py
+++ b/src/ix/app.py
@@ -1,13 +1,15 @@
 """FastAPI app factory + lifespan.
 
 ``create_app()`` wires the REST router on top of a lifespan that spawns the
-worker loop (Task 3.5) and, optionally, the pg_queue listener (Task 3.6).
-Tests that don't care about the worker call ``create_app(spawn_worker=False)``
-so the lifespan returns cleanly.
+worker loop (Task 3.5) and the pg_queue listener (Task 3.6). Tests that
+don't care about the worker call ``create_app(spawn_worker=False)`` so the
+lifespan returns cleanly.
 
 The factory is parameterised (``spawn_worker``) instead of env-gated because
 pytest runs multiple app instances per session and we want the decision local
-to each call, not inferred from ``IX_*`` variables.
+to each call, not inferred from ``IX_*`` variables. The listener is also
+gated on ``spawn_worker`` — the listener is only useful when a worker is
+draining the queue, so the two share one flag.
 """
 
 from __future__ import annotations
@@ -26,25 +28,49 @@ def create_app(*, spawn_worker: bool = True) -> FastAPI:
     Parameters
     ----------
     spawn_worker:
-        When True (default), the lifespan spawns the background worker task.
-        Integration tests that only exercise the REST adapter pass False so
-        jobs pile up as ``pending`` and the tests can assert on their state
-        without a racing worker mutating them.
+        When True (default), the lifespan spawns the background worker task
+        and the pg_queue listener. Integration tests that only exercise the
+        REST adapter pass False so jobs pile up as ``pending`` and the tests
+        can assert on their state without a racing worker mutating them.
     """
 
     @asynccontextmanager
     async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
         worker_task = None
+        listener = None
         if spawn_worker:
-            # Task 3.5 fills in the real spawn. We leave the hook so this
-            # file doesn't churn again when the worker lands — the adapter
-            # already imports `ix.worker.loop` once it exists.
+            # Pipeline factory + listener wiring live in Chunk 4's
+            # production entrypoint; keeping this path best-effort lets the
+            # lifespan still start even on a box where Ollama/Surya aren't
+            # available (the listener just gives us a passive 10 s poll).
+            try:
+                from ix.adapters.pg_queue.listener import (
+                    PgQueueListener,
+                    asyncpg_dsn_from_sqlalchemy_url,
+                )
+                from ix.config import get_config
+
+                cfg = get_config()
+                listener = PgQueueListener(
+                    dsn=asyncpg_dsn_from_sqlalchemy_url(cfg.postgres_url)
+                )
+                await listener.start()
+            except Exception:
+                import logging
+
+                # Still best-effort — the worker keeps its fallback poll —
+                # but leave a trace of why the listener is missing.
+                logging.getLogger(__name__).warning(
+                    "pg_queue listener unavailable; relying on fallback poll",
+                    exc_info=True,
+                )
+                listener = None
+
             try:
                 from ix.worker.loop import spawn_worker_task
 
                 worker_task = await spawn_worker_task(_app)
             except ImportError:
-                # Worker module isn't in place yet (Task 3.5 not merged).
                 worker_task = None
         try:
             yield
@@ -53,6 +79,9 @@ def create_app(*, spawn_worker: bool = True) -> FastAPI:
                 worker_task.cancel()
                 with suppress(Exception):
                     await worker_task
+            if listener is not None:
+                with suppress(Exception):
+                    await listener.stop()
 
     app = FastAPI(lifespan=lifespan, title="infoxtractor", version="0.1.0")
     app.include_router(rest_router)
diff --git a/tests/integration/test_pg_queue_adapter.py b/tests/integration/test_pg_queue_adapter.py
new file mode 100644
index 0000000..6fb8915
--- /dev/null
+++ b/tests/integration/test_pg_queue_adapter.py
@@ -0,0 +1,168 @@
+"""Integration tests for the PgQueueListener + worker integration (Task 3.6).
+
+Two scenarios:
+
+1. NOTIFY delivered — worker wakes within a few seconds and picks the row up.
+2. Missed NOTIFY — the row still gets picked up by the fallback poll.
+
+Both run a real worker + listener against a live Postgres. We drive them via
+``asyncio.create_task`` + an "until done" watchdog.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING
+
+from sqlalchemy import text
+
+from ix.adapters.pg_queue.listener import PgQueueListener, asyncpg_dsn_from_sqlalchemy_url
+from ix.contracts.request import Context, RequestIX
+from ix.pipeline.pipeline import Pipeline
+from ix.pipeline.step import Step
+from ix.store import jobs_repo
+from ix.worker.loop import Worker
+
+if TYPE_CHECKING:
+    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
+
+
+class _PassingStep(Step):
+    """Same minimal fake as test_worker_loop — keeps these suites independent."""
+
+    step_name = "fake_pass"
+
+    async def validate(self, request_ix, response_ix):  # type: ignore[no-untyped-def]
+        return True
+
+    async def process(self, request_ix, response_ix):  # type: ignore[no-untyped-def]
+        response_ix.use_case = request_ix.use_case
+        return response_ix
+
+
+def _factory() -> Pipeline:
+    return Pipeline(steps=[_PassingStep()])
+
+
+async def _wait_for_status(
+    session_factory: async_sessionmaker[AsyncSession],
+    job_id,
+    target: str,
+    timeout_s: float,
+) -> bool:
+    """Poll the job row until its status equals ``target``; False on timeout."""
+
+    loop = asyncio.get_running_loop()
+    deadline = loop.time() + timeout_s
+    while loop.time() < deadline:
+        async with session_factory() as session:
+            job = await jobs_repo.get(session, job_id)
+        if job is not None and job.status == target:
+            return True
+        await asyncio.sleep(0.1)
+    return False
+
+
+async def test_notify_wakes_worker_within_2s(
+    session_factory: async_sessionmaker[AsyncSession],
+    postgres_url: str,
+) -> None:
+    """Direct INSERT + NOTIFY → worker picks it up fast (not via the poll)."""
+
+    listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
+    await listener.start()
+
+    worker = Worker(
+        session_factory=session_factory,
+        pipeline_factory=_factory,
+        # 60 s fallback poll — if the row is done within the short deadline
+        # below it's because NOTIFY woke us, not the poll.
+        poll_interval_seconds=60.0,
+        max_running_seconds=3600,
+        wait_for_work=listener.wait_for_work,
+    )
+    stop = asyncio.Event()
+    worker_task = asyncio.create_task(worker.run(stop))
+
+    try:
+        # Give the worker one tick to reach the sleep_or_wake branch.
+        await asyncio.sleep(0.3)
+
+        # Insert a pending row manually + NOTIFY — simulates a direct-SQL
+        # client like an external batch script.
+        request = RequestIX(
+            use_case="bank_statement_header",
+            ix_client_id="pgq",
+            request_id="notify-1",
+            context=Context(texts=["hi"]),
+        )
+        async with session_factory() as session:
+            job = await jobs_repo.insert_pending(session, request, callback_url=None)
+            await session.commit()
+        async with session_factory() as session:
+            # pg_notify() takes the payload as a bound parameter, so nothing
+            # is interpolated into the SQL text.
+            await session.execute(
+                text("SELECT pg_notify('ix_jobs_new', :payload)"),
+                {"payload": str(job.job_id)},
+            )
+            await session.commit()
+
+        assert await _wait_for_status(session_factory, job.job_id, "done", 3.0), (
+            "worker didn't pick up the NOTIFY'd row in time"
+        )
+    finally:
+        # Always tear down, even on a failed assert, so a red test doesn't
+        # leak the worker task or the LISTEN connection into later tests.
+        stop.set()
+        try:
+            await worker_task
+        finally:
+            await listener.stop()
+
+
+async def test_missed_notify_falls_back_to_poll(
+    session_factory: async_sessionmaker[AsyncSession],
+    postgres_url: str,
+) -> None:
+    """Row lands without a NOTIFY; fallback poll still picks it up."""
+
+    listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
+    await listener.start()
+
+    worker = Worker(
+        session_factory=session_factory,
+        pipeline_factory=_factory,
+        # Short poll so the fallback kicks in quickly — we need the test
+        # to finish in seconds, not the spec's 10 s.
+        poll_interval_seconds=0.5,
+        max_running_seconds=3600,
+        wait_for_work=listener.wait_for_work,
+    )
+    stop = asyncio.Event()
+    worker_task = asyncio.create_task(worker.run(stop))
+
+    try:
+        await asyncio.sleep(0.3)
+
+        # Insert without NOTIFY: simulate a buggy writer.
+        request = RequestIX(
+            use_case="bank_statement_header",
+            ix_client_id="pgq",
+            request_id="missed-1",
+            context=Context(texts=["hi"]),
+        )
+        async with session_factory() as session:
+            job = await jobs_repo.insert_pending(session, request, callback_url=None)
+            await session.commit()
+
+        assert await _wait_for_status(session_factory, job.job_id, "done", 5.0), (
+            "fallback poll didn't pick up the row"
+        )
+    finally:
+        # Tear down even on a failed assert so nothing leaks into later tests.
+        stop.set()
+        try:
+            await worker_task
+        finally:
+            await listener.stop()