Merge pull request 'feat(genai): OllamaClient structured-output /api/chat backend (spec 6)' (#24) from feat/ollama-client into main
All checks were successful
tests / test (push) Successful in 1m15s
All checks were successful
tests / test (push) Successful in 1m15s
This commit is contained in:
commit
0f045f814a
4 changed files with 531 additions and 0 deletions
203
src/ix/genai/ollama_client.py
Normal file
203
src/ix/genai/ollama_client.py
Normal file
|
|
@@ -0,0 +1,203 @@
|
|||
"""OllamaClient — real :class:`GenAIClient` implementation (spec §6 GenAIStep).
|
||||
|
||||
Wraps the Ollama ``/api/chat`` structured-output endpoint. Per spec:
|
||||
|
||||
* POST ``{base_url}/api/chat`` with ``format = <pydantic JSON schema>``,
|
||||
``stream = false``, and ``options`` carrying provider-neutral knobs
|
||||
(``temperature`` mapped, ``reasoning_effort`` dropped — Ollama ignores it).
|
||||
* Messages are passed through. Content-parts lists (``[{"type":"text",...}]``)
|
||||
are joined to a single string because MVP models (``gpt-oss:20b`` /
|
||||
``qwen2.5:32b``) don't accept native content-parts.
|
||||
* Per-call timeout is enforced via ``httpx``. A connection refusal, read
|
||||
timeout, or 5xx maps to ``IX_002_000``. A 2xx whose ``message.content`` is
|
||||
not valid JSON for the schema maps to ``IX_002_001``.
|
||||
|
||||
``selfcheck()`` targets ``/api/tags`` with a fixed 5 s timeout and is what
|
||||
``/healthz`` consumes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.genai.client import GenAIInvocationResult, GenAIUsage
|
||||
|
||||
_OLLAMA_TAGS_TIMEOUT_S: float = 5.0
|
||||
_BODY_SNIPPET_MAX_CHARS: int = 240
|
||||
|
||||
|
||||
class OllamaClient:
    """Async Ollama backend satisfying :class:`~ix.genai.client.GenAIClient`.

    Parameters
    ----------
    base_url:
        Root URL of the Ollama server (e.g. ``http://host.docker.internal:11434``).
        Trailing slashes are stripped.
    per_call_timeout_s:
        Hard per-call timeout for ``/api/chat``. Spec default: 1500 s.
    """

    def __init__(self, base_url: str, per_call_timeout_s: float) -> None:
        self._base_url = base_url.rstrip("/")
        self._per_call_timeout_s = per_call_timeout_s

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Run one structured-output chat call; parse into ``response_schema``.

        Raises
        ------
        IXException
            ``IX_002_000`` for transport failures, non-2xx responses, and a
            non-JSON response envelope; ``IX_002_001`` when the model's
            ``message.content`` does not satisfy ``response_schema``.
        """
        body = self._translate_request(request_kwargs, response_schema)
        url = f"{self._base_url}/api/chat"

        # A fresh client per call keeps the timeout strictly per-call.
        # Raw ConnectionError/TimeoutError are normally wrapped by httpx, but
        # are kept in the tuple as a defensive fallback; all transport
        # failures map to IX_002_000 identically.
        try:
            async with httpx.AsyncClient(timeout=self._per_call_timeout_s) as http:
                resp = await http.post(url, json=body)
        except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama {exc.__class__.__name__}: {exc}",
            ) from exc

        # 4xx and 5xx map to the same error code with the same detail, so a
        # single threshold check replaces the former duplicated branches.
        if resp.status_code >= 400:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=(
                    f"ollama HTTP {resp.status_code}: "
                    f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            )

        try:
            payload = resp.json()
        except ValueError as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama non-JSON body: {resp.text[:_BODY_SNIPPET_MAX_CHARS]}",
            ) from exc

        content = (payload.get("message") or {}).get("content") or ""
        try:
            parsed = response_schema.model_validate_json(content)
        except ValidationError as exc:
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: {exc.__class__.__name__}: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc
        except ValueError as exc:
            # ``model_validate_json`` raises ValueError on invalid JSON (not
            # a ValidationError). Treat as structured-output failure.
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: invalid JSON: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc

        # Ollama reports token counts as eval counters; absent keys count as 0.
        usage = GenAIUsage(
            prompt_tokens=int(payload.get("prompt_eval_count") or 0),
            completion_tokens=int(payload.get("eval_count") or 0),
        )
        model_name = str(payload.get("model") or request_kwargs.get("model") or "")
        return GenAIInvocationResult(parsed=parsed, usage=usage, model_name=model_name)

    async def selfcheck(
        self, expected_model: str
    ) -> Literal["ok", "degraded", "fail"]:
        """Probe ``/api/tags`` for ``/healthz``.

        ``ok`` when the server answers 2xx and ``expected_model`` is listed;
        ``degraded`` when reachable but the model is missing; ``fail``
        otherwise. Spec §5, §11.
        """
        try:
            async with httpx.AsyncClient(timeout=_OLLAMA_TAGS_TIMEOUT_S) as http:
                resp = await http.get(f"{self._base_url}/api/tags")
        except (httpx.HTTPError, ConnectionError, TimeoutError):
            return "fail"

        if resp.status_code != 200:
            return "fail"

        try:
            payload = resp.json()
        except ValueError:
            return "fail"

        models = payload.get("models") or []
        names = {str(entry.get("name", "")) for entry in models}
        if expected_model in names:
            return "ok"
        return "degraded"

    def _translate_request(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> dict[str, Any]:
        """Map provider-neutral kwargs to Ollama's /api/chat body."""
        messages = self._translate_messages(
            list(request_kwargs.get("messages") or [])
        )
        body: dict[str, Any] = {
            "model": request_kwargs.get("model"),
            "messages": messages,
            "stream": False,
            # ``format`` carries the JSON schema Ollama must constrain to.
            "format": response_schema.model_json_schema(),
        }

        options: dict[str, Any] = {}
        if "temperature" in request_kwargs:
            options["temperature"] = request_kwargs["temperature"]
        # reasoning_effort intentionally dropped — Ollama doesn't support it.
        if options:
            body["options"] = options
        return body

    @staticmethod
    def _translate_messages(
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Collapse content-parts lists into single strings for Ollama.

        Non-``text`` parts (e.g. images) are discarded; plain-string content
        passes through untouched. All other message keys are preserved.
        """
        out: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                text_parts = [
                    str(part.get("text", ""))
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                ]
                new_content = "\n".join(text_parts)
            else:
                new_content = content
            out.append({**msg, "content": new_content})
        return out
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = ["OllamaClient"]
|
||||
0
tests/live/__init__.py
Normal file
0
tests/live/__init__.py
Normal file
70
tests/live/test_ollama_client_live.py
Normal file
70
tests/live/test_ollama_client_live.py
Normal file
|
|
@@ -0,0 +1,70 @@
|
|||
"""Live tests for :class:`OllamaClient` — gated on ``IX_TEST_OLLAMA=1``.
|
||||
|
||||
Never runs in CI (Forgejo runner has no LAN access to Ollama). Run locally::
|
||||
|
||||
IX_TEST_OLLAMA=1 uv run pytest tests/live/test_ollama_client_live.py -v
|
||||
|
||||
Assumes the Ollama server at ``http://192.168.68.42:11434`` already has
|
||||
``gpt-oss:20b`` pulled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from ix.genai.ollama_client import OllamaClient
|
||||
from ix.use_cases.bank_statement_header import BankStatementHeader
|
||||
|
||||
# Gate: these tests hit a real LAN Ollama server, so they are marked "live"
# and skipped unless the developer opts in via IX_TEST_OLLAMA=1.
pytestmark = [
    pytest.mark.live,
    pytest.mark.skipif(
        os.environ.get("IX_TEST_OLLAMA") != "1",
        reason="live: IX_TEST_OLLAMA=1 required",
    ),
]

# LAN address of the development Ollama box and the model assumed pulled there.
_OLLAMA_URL = "http://192.168.68.42:11434"
_MODEL = "gpt-oss:20b"
|
||||
|
||||
|
||||
async def test_structured_output_round_trip() -> None:
    """Real Ollama returns a parsed BankStatementHeader instance."""
    system_msg = {
        "role": "system",
        "content": (
            "You extract bank statement header fields. "
            "Return valid JSON matching the given schema. "
            "Do not invent values."
        ),
    }
    user_msg = {
        "role": "user",
        "content": (
            "Bank: Deutsche Kreditbank (DKB)\n"
            "Currency: EUR\n"
            "IBAN: DE89370400440532013000\n"
            "Period: 2025-01-01 to 2025-01-31"
        ),
    }
    client = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=300.0)
    result = await client.invoke(
        request_kwargs={"model": _MODEL, "messages": [system_msg, user_msg]},
        response_schema=BankStatementHeader,
    )
    header = result.parsed
    assert isinstance(header, BankStatementHeader)
    assert isinstance(header.bank_name, str)
    assert header.bank_name  # non-empty
    assert isinstance(header.currency, str)
    assert result.model_name  # server echoes a model name
|
||||
|
||||
|
||||
async def test_selfcheck_ok_against_real_server() -> None:
    """``selfcheck`` returns ``ok`` when the target model is pulled."""
    probe = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=5.0)
    status = await probe.selfcheck(expected_model=_MODEL)
    assert status == "ok"
|
||||
258
tests/unit/test_ollama_client.py
Normal file
258
tests/unit/test_ollama_client.py
Normal file
|
|
@@ -0,0 +1,258 @@
|
|||
"""Tests for :class:`OllamaClient` — hermetic, pytest-httpx-driven.
|
||||
|
||||
Covers spec §6 GenAIStep Ollama call contract:
|
||||
|
||||
* POST body shape (model / messages / format / stream / options).
|
||||
* Response parsing → :class:`GenAIInvocationResult`.
|
||||
* Error mapping: connection / timeout / 5xx → ``IX_002_000``;
|
||||
schema-violating body → ``IX_002_001``.
|
||||
* ``selfcheck()``: tags-reachable + model-listed → ``ok``;
|
||||
reachable-but-missing → ``degraded``; unreachable → ``fail``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from ix.errors import IXErrorCode, IXException
|
||||
from ix.genai.ollama_client import OllamaClient
|
||||
|
||||
|
||||
class _Schema(BaseModel):
    """Trivial structured-output schema for the round-trip tests."""

    # Required field — a response body missing it must raise IX_002_001.
    bank_name: str
    # Optional field — exercises tolerant parsing when the model omits it.
    account_number: str | None = None
|
||||
|
||||
|
||||
def _ollama_chat_ok_body(content_json: str) -> dict:
|
||||
"""Build a minimal Ollama /api/chat success body."""
|
||||
return {
|
||||
"model": "gpt-oss:20b",
|
||||
"message": {"role": "assistant", "content": content_json},
|
||||
"done": True,
|
||||
"eval_count": 42,
|
||||
"prompt_eval_count": 17,
|
||||
}
|
||||
|
||||
|
||||
class TestInvokeHappyPath:
    """Happy-path contract: request body shape and response parsing."""

    async def test_posts_to_chat_endpoint_with_format_and_no_stream(
        self, httpx_mock: HTTPXMock
    ) -> None:
        import json

        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"DKB","account_number":"DE89"}'),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        kwargs = {
            "model": "gpt-oss:20b",
            "messages": [
                {"role": "system", "content": "You extract."},
                {"role": "user", "content": "Doc body"},
            ],
            "temperature": 0.2,
            "reasoning_effort": "high",  # dropped silently
        }
        result = await client.invoke(request_kwargs=kwargs, response_schema=_Schema)

        assert result.parsed == _Schema(bank_name="DKB", account_number="DE89")
        assert result.model_name == "gpt-oss:20b"
        assert result.usage.prompt_tokens == 17
        assert result.usage.completion_tokens == 42

        # Verify request shape.
        sent = httpx_mock.get_requests()
        assert len(sent) == 1
        body_json = json.loads(sent[0].read().decode())
        assert body_json["model"] == "gpt-oss:20b"
        assert body_json["stream"] is False
        assert body_json["format"] == _Schema.model_json_schema()
        assert body_json["options"]["temperature"] == 0.2
        assert "reasoning_effort" not in body_json
        assert body_json["messages"] == [
            {"role": "system", "content": "You extract."},
            {"role": "user", "content": "Doc body"},
        ]

    async def test_text_parts_content_list_is_joined(
        self, httpx_mock: HTTPXMock
    ) -> None:
        import json

        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"X"}'),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        parts_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "part-a"},
                {"type": "text", "text": "part-b"},
            ],
        }
        await client.invoke(
            request_kwargs={"model": "gpt-oss:20b", "messages": [parts_message]},
            response_schema=_Schema,
        )
        request_body = json.loads(httpx_mock.get_requests()[0].read())
        assert request_body["messages"] == [
            {"role": "user", "content": "part-a\npart-b"}
        ]
|
||||
|
||||
|
||||
class TestInvokeErrorPaths:
    """Transport, HTTP, and schema failures map to the spec'd IX error codes."""

    @staticmethod
    async def _invoke_simple(client: OllamaClient) -> None:
        # Shared minimal call used by every error-path test below.
        await client.invoke(
            request_kwargs={
                "model": "gpt-oss:20b",
                "messages": [{"role": "user", "content": "hi"}],
            },
            response_schema=_Schema,
        )

    async def test_connection_error_maps_to_002_000(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=1.0
        )
        with pytest.raises(IXException) as ei:
            await self._invoke_simple(client)
        assert ei.value.code is IXErrorCode.IX_002_000

    async def test_read_timeout_maps_to_002_000(self, httpx_mock: HTTPXMock) -> None:
        httpx_mock.add_exception(httpx.ReadTimeout("slow"))
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=0.5
        )
        with pytest.raises(IXException) as ei:
            await self._invoke_simple(client)
        assert ei.value.code is IXErrorCode.IX_002_000

    async def test_500_maps_to_002_000_with_body_snippet(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            status_code=500,
            text="boom boom server broken",
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await self._invoke_simple(client)
        assert ei.value.code is IXErrorCode.IX_002_000
        assert "boom" in (ei.value.detail or "")

    async def test_200_with_invalid_json_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body("not-json"),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await self._invoke_simple(client)
        assert ei.value.code is IXErrorCode.IX_002_001

    async def test_200_with_schema_violation_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        # Missing required `bank_name` field.
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/chat",
            method="POST",
            json=_ollama_chat_ok_body('{"account_number":"DE89"}'),
        )
        client = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        with pytest.raises(IXException) as ei:
            await self._invoke_simple(client)
        assert ei.value.code is IXErrorCode.IX_002_001
|
||||
|
||||
|
||||
class TestSelfcheck:
    """``selfcheck`` tri-state: ok / degraded / fail from /api/tags."""

    async def test_selfcheck_ok_when_model_listed(
        self, httpx_mock: HTTPXMock
    ) -> None:
        tags = {"models": [{"name": "gpt-oss:20b"}, {"name": "qwen2.5:32b"}]}
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/tags",
            method="GET",
            json=tags,
        )
        probe = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await probe.selfcheck(expected_model="gpt-oss:20b") == "ok"

    async def test_selfcheck_degraded_when_model_missing(
        self, httpx_mock: HTTPXMock
    ) -> None:
        tags = {"models": [{"name": "qwen2.5:32b"}]}
        httpx_mock.add_response(
            url="http://ollama.test:11434/api/tags",
            method="GET",
            json=tags,
        )
        probe = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await probe.selfcheck(expected_model="gpt-oss:20b") == "degraded"

    async def test_selfcheck_fail_on_connection_error(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        probe = OllamaClient(
            base_url="http://ollama.test:11434", per_call_timeout_s=5.0
        )
        assert await probe.selfcheck(expected_model="gpt-oss:20b") == "fail"
|
||||
Loading…
Reference in a new issue