The shared postgis container is bound to 127.0.0.1 on the host (security hardening, infrastructure §T12). Ollama is similarly LAN-hardened. The previous `host.docker.internal + extra_hosts: host-gateway` approach points at the bridge gateway IP, not loopback, so the container couldn't reach either service. Switch to `network_mode: host` (same pattern goldstein uses) and update the default IX_POSTGRES_URL / IX_OLLAMA_URL to 127.0.0.1. Keep the GPU reservation block; drop the now-meaningless ports: declaration (host mode publishes directly). AppConfig defaults + .env.example + test_config assertions + inline docstring examples all follow. Caught on fourth deploy attempt. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
203 lines · 7.2 KiB · Python
"""OllamaClient — real :class:`GenAIClient` implementation (spec §6 GenAIStep).
|
|
|
|
Wraps the Ollama ``/api/chat`` structured-output endpoint. Per spec:
|
|
|
|
* POST ``{base_url}/api/chat`` with ``format = <pydantic JSON schema>``,
|
|
``stream = false``, and ``options`` carrying provider-neutral knobs
|
|
(``temperature`` mapped, ``reasoning_effort`` dropped — Ollama ignores it).
|
|
* Messages are passed through. Content-parts lists (``[{"type":"text",...}]``)
|
|
are joined to a single string because MVP models (``gpt-oss:20b`` /
|
|
``qwen2.5:32b``) don't accept native content-parts.
|
|
* Per-call timeout is enforced via ``httpx``. A connection refusal, read
|
|
timeout, or 5xx maps to ``IX_002_000``. A 2xx whose ``message.content`` is
|
|
not valid JSON for the schema maps to ``IX_002_001``.
|
|
|
|
``selfcheck()`` targets ``/api/tags`` with a fixed 5 s timeout and is what
|
|
``/healthz`` consumes.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Literal
|
|
|
|
import httpx
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from ix.errors import IXErrorCode, IXException
|
|
from ix.genai.client import GenAIInvocationResult, GenAIUsage
|
|
|
|
# Fixed timeout for the ``/api/tags`` health probe used by ``selfcheck()``
# (the module docstring pins this at 5 s for ``/healthz``).
_OLLAMA_TAGS_TIMEOUT_S: float = 5.0
# Max characters of a response/content body echoed into IXException details,
# keeping error messages (and therefore logs) bounded.
_BODY_SNIPPET_MAX_CHARS: int = 240
|
|
|
|
|
|
class OllamaClient:
    """Async Ollama backend satisfying :class:`~ix.genai.client.GenAIClient`.

    Parameters
    ----------
    base_url:
        Root URL of the Ollama server (e.g. ``http://127.0.0.1:11434``).
        Trailing slashes are stripped.
    per_call_timeout_s:
        Hard per-call timeout for ``/api/chat``. Spec default: 1500 s.
    """

    def __init__(self, base_url: str, per_call_timeout_s: float) -> None:
        self._base_url = base_url.rstrip("/")
        self._per_call_timeout_s = per_call_timeout_s

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Run one structured-output chat call; parse into ``response_schema``.

        Raises
        ------
        IXException
            ``IX_002_000`` for transport failures, any non-2xx HTTP status,
            or a non-JSON response envelope; ``IX_002_001`` when the model's
            ``message.content`` fails to validate against the schema.
        """
        body = self._translate_request(request_kwargs, response_schema)
        url = f"{self._base_url}/api/chat"

        try:
            async with httpx.AsyncClient(timeout=self._per_call_timeout_s) as http:
                resp = await http.post(url, json=body)
        except httpx.HTTPError as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama {exc.__class__.__name__}: {exc}",
            ) from exc
        except (ConnectionError, TimeoutError) as exc:  # pragma: no cover - httpx wraps these
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama {exc.__class__.__name__}: {exc}",
            ) from exc

        # 5xx and 4xx previously raised from two duplicated branches with an
        # identical error code and byte-identical detail string; a single
        # ``>= 400`` check is behaviorally equivalent and removes the copy.
        if resp.status_code >= 400:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=(
                    f"ollama HTTP {resp.status_code}: "
                    f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            )

        try:
            payload = resp.json()
        except ValueError as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama non-JSON body: {resp.text[:_BODY_SNIPPET_MAX_CHARS]}",
            ) from exc

        # Ollama's envelope nests the model output at ``message.content``;
        # missing/None collapses to "" so schema validation fails cleanly below.
        content = (payload.get("message") or {}).get("content") or ""
        try:
            parsed = response_schema.model_validate_json(content)
        except ValidationError as exc:
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: {exc.__class__.__name__}: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc
        except ValueError as exc:
            # ``model_validate_json`` raises ValueError on invalid JSON (not
            # a ValidationError). Treat as structured-output failure. Order
            # matters: ValidationError subclasses ValueError, so it is caught
            # by the branch above first.
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: invalid JSON: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc

        # Token accounting: Ollama reports prompt/completion counts as
        # ``prompt_eval_count`` / ``eval_count``; absent fields count as 0.
        usage = GenAIUsage(
            prompt_tokens=int(payload.get("prompt_eval_count") or 0),
            completion_tokens=int(payload.get("eval_count") or 0),
        )
        model_name = str(payload.get("model") or request_kwargs.get("model") or "")
        return GenAIInvocationResult(parsed=parsed, usage=usage, model_name=model_name)

    async def selfcheck(
        self, expected_model: str
    ) -> Literal["ok", "degraded", "fail"]:
        """Probe ``/api/tags`` for ``/healthz``.

        ``ok`` when the server answers 2xx and ``expected_model`` is listed;
        ``degraded`` when reachable but the model is missing; ``fail``
        otherwise. Spec §5, §11.
        """
        try:
            async with httpx.AsyncClient(timeout=_OLLAMA_TAGS_TIMEOUT_S) as http:
                resp = await http.get(f"{self._base_url}/api/tags")
        except (httpx.HTTPError, ConnectionError, TimeoutError):
            return "fail"

        if resp.status_code != 200:
            return "fail"

        try:
            payload = resp.json()
        except ValueError:
            return "fail"

        # Exact-name match against the tags list (e.g. "qwen2.5:32b").
        models = payload.get("models") or []
        names = {str(entry.get("name", "")) for entry in models}
        if expected_model in names:
            return "ok"
        return "degraded"

    def _translate_request(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> dict[str, Any]:
        """Map provider-neutral kwargs to Ollama's /api/chat body.

        ``format`` carries the pydantic JSON schema for structured output;
        ``stream`` is always false (the caller wants one parsed object).
        """
        messages = self._translate_messages(
            list(request_kwargs.get("messages") or [])
        )
        body: dict[str, Any] = {
            "model": request_kwargs.get("model"),
            "messages": messages,
            "stream": False,
            "format": response_schema.model_json_schema(),
        }

        options: dict[str, Any] = {}
        if "temperature" in request_kwargs:
            options["temperature"] = request_kwargs["temperature"]
        # reasoning_effort intentionally dropped — Ollama doesn't support it.
        if options:
            body["options"] = options
        return body

    @staticmethod
    def _translate_messages(
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Collapse content-parts lists into single strings for Ollama.

        Only ``{"type": "text"}`` parts are kept (joined with newlines);
        non-list content passes through unchanged. All other message keys
        (role, etc.) are preserved.
        """
        out: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                text_parts = [
                    str(part.get("text", ""))
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                ]
                new_content = "\n".join(text_parts)
            else:
                new_content = content
            out.append({**msg, "content": new_content})
        return out
|
|
|
|
|
__all__ = ["OllamaClient"]
|