Merge pull request 'feat(pg-queue): LISTEN ix_jobs_new + 10s fallback poll' (#23) from feat/pg-queue-adapter into main
All checks were successful
tests / test (push) Successful in 1m11s
All checks were successful
tests / test (push) Successful in 1m11s
This commit is contained in:
commit
6183b9c886
4 changed files with 312 additions and 12 deletions
15
src/ix/adapters/pg_queue/__init__.py
Normal file
15
src/ix/adapters/pg_queue/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Postgres queue adapter — ``LISTEN ix_jobs_new`` + 10 s fallback poll.
|
||||||
|
|
||||||
|
This is a secondary transport: a direct-SQL writer can insert a row and
|
||||||
|
``NOTIFY ix_jobs_new, '<job_id>'`` and the worker wakes up within the roundtrip
|
||||||
|
time rather than the 10 s fallback poll. The REST adapter doesn't need the
|
||||||
|
listener because the worker is already running in-process; this exists for
|
||||||
|
external callers who bypass the REST API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ix.adapters.pg_queue.listener import (
|
||||||
|
PgQueueListener,
|
||||||
|
asyncpg_dsn_from_sqlalchemy_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["PgQueueListener", "asyncpg_dsn_from_sqlalchemy_url"]
|
||||||
111
src/ix/adapters/pg_queue/listener.py
Normal file
111
src/ix/adapters/pg_queue/listener.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Dedicated asyncpg connection that LISTENs to ``ix_jobs_new``.
|
||||||
|
|
||||||
|
We hold the connection *outside* the SQLAlchemy pool because SQLAlchemy's
|
||||||
|
asyncpg dialect doesn't expose the raw connection in a way that survives
|
||||||
|
the pool's checkout/checkin dance, and LISTEN needs a connection that
|
||||||
|
stays open for the full session to receive asynchronous notifications.
|
||||||
|
|
||||||
|
The adapter contract the worker sees is a single coroutine-factory,
|
||||||
|
``wait_for_work(poll_seconds)``, which completes either when a NOTIFY
|
||||||
|
arrives or when ``poll_seconds`` elapse. The worker doesn't care which
|
||||||
|
woke it — it just goes back to its claim query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
|
||||||
|
def asyncpg_dsn_from_sqlalchemy_url(url: str) -> str:
    """Strip the SQLAlchemy ``postgresql+asyncpg://`` prefix for raw asyncpg.

    asyncpg's connect() expects the plain ``postgres://user:pass@host/db``
    shape; the ``+driver`` segment SQLAlchemy adds confuses it. We also
    percent-decode the password — asyncpg accepts the raw form but not the
    pre-encoded ``%21`` passwords we sometimes use in dev.

    Parameters
    ----------
    url:
        A SQLAlchemy-style URL, e.g. ``postgresql+asyncpg://u:p@host:5432/db``.

    Returns
    -------
    str
        The same URL with the ``+driver`` suffix removed and credentials
        percent-decoded. Query parameters (e.g. ``?sslmode=require``) are
        preserved.
    """
    parsed = urlparse(url)
    # "postgresql+asyncpg" -> "postgresql"; schemes without "+" pass through.
    scheme = parsed.scheme.split("+", 1)[0]
    user = unquote(parsed.username) if parsed.username else ""
    password = unquote(parsed.password) if parsed.password else ""
    auth = ""
    if user:
        auth = f"{user}"
        if password:
            auth += f":{password}"
        auth += "@"
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc += f":{parsed.port}"
    # BUGFIX: the original dropped parsed.query, silently losing options
    # such as ``?sslmode=require`` on the way to asyncpg.
    query = f"?{parsed.query}" if parsed.query else ""
    return f"{scheme}://{auth}{netloc}{parsed.path}{query}"
|
|
||||||
|
|
||||||
|
class PgQueueListener:
|
||||||
|
"""Long-lived asyncpg connection that sets an event on each NOTIFY.
|
||||||
|
|
||||||
|
The worker uses :meth:`wait_for_work` as its ``wait_for_work`` hook:
|
||||||
|
one call resolves when either a NOTIFY is received OR ``timeout``
|
||||||
|
seconds elapse, whichever comes first. The event is cleared after each
|
||||||
|
resolution so subsequent waits don't see stale state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dsn: str, channel: str = "ix_jobs_new") -> None:
|
||||||
|
self._dsn = dsn
|
||||||
|
self._channel = channel
|
||||||
|
self._conn: asyncpg.Connection | None = None
|
||||||
|
self._event = asyncio.Event()
|
||||||
|
# Protect add_listener / remove_listener against concurrent
|
||||||
|
# start/stop — shouldn't happen in practice but a stray double-stop
|
||||||
|
# from a lifespan shutdown shouldn't raise ``listener not found``.
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if self._conn is not None:
|
||||||
|
return
|
||||||
|
self._conn = await asyncpg.connect(self._dsn)
|
||||||
|
await self._conn.add_listener(self._channel, self._on_notify)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if self._conn is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self._conn.remove_listener(self._channel, self._on_notify)
|
||||||
|
finally:
|
||||||
|
await self._conn.close()
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
|
def _on_notify(
|
||||||
|
self,
|
||||||
|
connection: asyncpg.Connection,
|
||||||
|
pid: int,
|
||||||
|
channel: str,
|
||||||
|
payload: str,
|
||||||
|
) -> None:
|
||||||
|
"""asyncpg listener callback — signals the waiter."""
|
||||||
|
|
||||||
|
# We don't care about payload/pid/channel — any NOTIFY on our
|
||||||
|
# channel means "go check for pending rows". Keep the body tiny so
|
||||||
|
# asyncpg's single dispatch loop stays snappy.
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
async def wait_for_work(self, timeout: float) -> None:
|
||||||
|
"""Resolve when a NOTIFY arrives or ``timeout`` seconds pass.
|
||||||
|
|
||||||
|
We wait on the event with a timeout. ``asyncio.wait_for`` raises
|
||||||
|
:class:`asyncio.TimeoutError` on expiry; we swallow it because the
|
||||||
|
worker treats "either signal" identically. The event is cleared
|
||||||
|
after every wait so the next call starts fresh.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._event.wait(), timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._event.clear()
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""FastAPI app factory + lifespan.
|
"""FastAPI app factory + lifespan.
|
||||||
|
|
||||||
``create_app()`` wires the REST router on top of a lifespan that spawns the
|
``create_app()`` wires the REST router on top of a lifespan that spawns the
|
||||||
worker loop (Task 3.5) and, optionally, the pg_queue listener (Task 3.6).
|
worker loop (Task 3.5) and the pg_queue listener (Task 3.6). Tests that
|
||||||
Tests that don't care about the worker call ``create_app(spawn_worker=False)``
|
don't care about the worker call ``create_app(spawn_worker=False)`` so the
|
||||||
so the lifespan returns cleanly.
|
lifespan returns cleanly.
|
||||||
|
|
||||||
The factory is parameterised (``spawn_worker``) instead of env-gated because
|
The factory is parameterised (``spawn_worker``) instead of env-gated because
|
||||||
pytest runs multiple app instances per session and we want the decision local
|
pytest runs multiple app instances per session and we want the decision local
|
||||||
to each call, not inferred from ``IX_*`` variables.
|
to each call, not inferred from ``IX_*`` variables. The listener is also
|
||||||
|
gated on ``spawn_worker`` — the listener is only useful when a worker is
|
||||||
|
draining the queue, so the two share one flag.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -26,25 +28,41 @@ def create_app(*, spawn_worker: bool = True) -> FastAPI:
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spawn_worker:
|
spawn_worker:
|
||||||
When True (default), the lifespan spawns the background worker task.
|
When True (default), the lifespan spawns the background worker task
|
||||||
Integration tests that only exercise the REST adapter pass False so
|
and the pg_queue listener. Integration tests that only exercise the
|
||||||
jobs pile up as ``pending`` and the tests can assert on their state
|
REST adapter pass False so jobs pile up as ``pending`` and the tests
|
||||||
without a racing worker mutating them.
|
can assert on their state without a racing worker mutating them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||||
worker_task = None
|
worker_task = None
|
||||||
|
listener = None
|
||||||
if spawn_worker:
|
if spawn_worker:
|
||||||
# Task 3.5 fills in the real spawn. We leave the hook so this
|
# Pipeline factory + listener wiring live in Chunk 4's
|
||||||
# file doesn't churn again when the worker lands — the adapter
|
# production entrypoint; keeping this path best-effort lets the
|
||||||
# already imports `ix.worker.loop` once it exists.
|
# lifespan still start even on a box where Ollama/Surya aren't
|
||||||
|
# available (the listener just gives us a passive 10 s poll).
|
||||||
|
try:
|
||||||
|
from ix.adapters.pg_queue.listener import (
|
||||||
|
PgQueueListener,
|
||||||
|
asyncpg_dsn_from_sqlalchemy_url,
|
||||||
|
)
|
||||||
|
from ix.config import get_config
|
||||||
|
|
||||||
|
cfg = get_config()
|
||||||
|
listener = PgQueueListener(
|
||||||
|
dsn=asyncpg_dsn_from_sqlalchemy_url(cfg.postgres_url)
|
||||||
|
)
|
||||||
|
await listener.start()
|
||||||
|
except Exception:
|
||||||
|
listener = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ix.worker.loop import spawn_worker_task
|
from ix.worker.loop import spawn_worker_task
|
||||||
|
|
||||||
worker_task = await spawn_worker_task(_app)
|
worker_task = await spawn_worker_task(_app)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Worker module isn't in place yet (Task 3.5 not merged).
|
|
||||||
worker_task = None
|
worker_task = None
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
|
@ -53,6 +71,9 @@ def create_app(*, spawn_worker: bool = True) -> FastAPI:
|
||||||
worker_task.cancel()
|
worker_task.cancel()
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await worker_task
|
await worker_task
|
||||||
|
if listener is not None:
|
||||||
|
with suppress(Exception):
|
||||||
|
await listener.stop()
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan, title="infoxtractor", version="0.1.0")
|
app = FastAPI(lifespan=lifespan, title="infoxtractor", version="0.1.0")
|
||||||
app.include_router(rest_router)
|
app.include_router(rest_router)
|
||||||
|
|
|
||||||
153
tests/integration/test_pg_queue_adapter.py
Normal file
153
tests/integration/test_pg_queue_adapter.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""Integration tests for the PgQueueListener + worker integration (Task 3.6).
|
||||||
|
|
||||||
|
Two scenarios:
|
||||||
|
|
||||||
|
1. NOTIFY delivered — worker wakes within ~1 s and picks the row up.
|
||||||
|
2. Missed NOTIFY — the row still gets picked up by the fallback poll.
|
||||||
|
|
||||||
|
Both run a real worker + listener against a live Postgres. We drive them via
|
||||||
|
``asyncio.gather`` + a "until done" watchdog.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from ix.adapters.pg_queue.listener import PgQueueListener, asyncpg_dsn_from_sqlalchemy_url
|
||||||
|
from ix.contracts.request import Context, RequestIX
|
||||||
|
from ix.pipeline.pipeline import Pipeline
|
||||||
|
from ix.pipeline.step import Step
|
||||||
|
from ix.store import jobs_repo
|
||||||
|
from ix.worker.loop import Worker
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
class _PassingStep(Step):
    """Minimal always-passing step (duplicated from test_worker_loop so the
    two suites stay independent of each other)."""

    step_name = "fake_pass"

    async def validate(self, request_ix, response_ix):  # type: ignore[no-untyped-def]
        # Every request is accepted.
        return True

    async def process(self, request_ix, response_ix):  # type: ignore[no-untyped-def]
        # Echo the use case back onto the response and hand it on.
        response_ix.use_case = request_ix.use_case
        return response_ix
||||||
|
def _factory() -> Pipeline:
    """Build a one-step pipeline containing only the passing fake."""
    pipeline = Pipeline(steps=[_PassingStep()])
    return pipeline
|
|
||||||
|
async def _wait_for_status(
    session_factory: async_sessionmaker[AsyncSession],
    job_id,
    target: str,
    timeout_s: float,
) -> bool:
    """Poll the jobs table until ``job_id`` reaches ``target`` or we time out.

    Returns True as soon as the status matches, False once ``timeout_s``
    elapses. A fresh session is opened for every check so each read sees
    committed state rather than a stale identity-map entry.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated in this position from 3.10 onwards.
    # Hoisted once — the same loop serves every clock read below.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while loop.time() < deadline:
        async with session_factory() as session:
            job = await jobs_repo.get(session, job_id)
            if job is not None and job.status == target:
                return True
        await asyncio.sleep(0.1)
    return False
|
|
||||||
|
async def test_notify_wakes_worker_within_2s(
    session_factory: async_sessionmaker[AsyncSession],
    postgres_url: str,
) -> None:
    """Direct INSERT + NOTIFY → worker picks it up fast (not via the poll)."""
    listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
    await listener.start()

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_factory,
        # 60 s fallback poll — if we still find the row within 2 s it's
        # because NOTIFY woke us, not the poll.
        poll_interval_seconds=60.0,
        max_running_seconds=3600,
        wait_for_work=listener.wait_for_work,
    )
    stop = asyncio.Event()
    worker_task = asyncio.create_task(worker.run(stop))

    try:
        # Give the worker one tick to reach the sleep_or_wake branch.
        await asyncio.sleep(0.3)

        # Insert a pending row manually + NOTIFY — simulates a direct-SQL
        # client like an external batch script.
        request = RequestIX(
            use_case="bank_statement_header",
            ix_client_id="pgq",
            request_id="notify-1",
            context=Context(texts=["hi"]),
        )
        async with session_factory() as session:
            job = await jobs_repo.insert_pending(session, request, callback_url=None)
            await session.commit()
        async with session_factory() as session:
            await session.execute(
                text(f"NOTIFY ix_jobs_new, '{job.job_id}'")
            )
            await session.commit()

        assert await _wait_for_status(session_factory, job.job_id, "done", 3.0), (
            "worker didn't pick up the NOTIFY'd row in time"
        )
    finally:
        # BUGFIX: teardown originally ran after the assert, so a failing
        # test leaked the worker task and the LISTEN connection. Always
        # stop both, pass or fail.
        stop.set()
        await worker_task
        await listener.stop()
|
|
||||||
|
async def test_missed_notify_falls_back_to_poll(
    session_factory: async_sessionmaker[AsyncSession],
    postgres_url: str,
) -> None:
    """Row lands without a NOTIFY; fallback poll still picks it up."""
    listener = PgQueueListener(dsn=asyncpg_dsn_from_sqlalchemy_url(postgres_url))
    await listener.start()

    worker = Worker(
        session_factory=session_factory,
        pipeline_factory=_factory,
        # Short poll so the fallback kicks in quickly — we need the test
        # to finish in seconds, not the spec's 10 s.
        poll_interval_seconds=0.5,
        max_running_seconds=3600,
        wait_for_work=listener.wait_for_work,
    )
    stop = asyncio.Event()
    worker_task = asyncio.create_task(worker.run(stop))

    try:
        await asyncio.sleep(0.3)

        # Insert without NOTIFY: simulate a buggy writer.
        request = RequestIX(
            use_case="bank_statement_header",
            ix_client_id="pgq",
            request_id="missed-1",
            context=Context(texts=["hi"]),
        )
        async with session_factory() as session:
            job = await jobs_repo.insert_pending(session, request, callback_url=None)
            await session.commit()

        assert await _wait_for_status(session_factory, job.job_id, "done", 5.0), (
            "fallback poll didn't pick up the row"
        )
    finally:
        # BUGFIX: teardown originally ran after the assert, so a failing
        # test leaked the worker task and the LISTEN connection. Always
        # stop both, pass or fail.
        stop.set()
        await worker_task
        await listener.stop()
||||||
Loading…
Reference in a new issue