class OllamaClient:
    """Async Ollama backend satisfying :class:`~ix.genai.client.GenAIClient`.

    Parameters
    ----------
    base_url:
        Root URL of the Ollama server (e.g. ``http://host.docker.internal:11434``).
        Trailing slashes are stripped.
    per_call_timeout_s:
        Timeout, in seconds, handed to ``httpx`` for every ``/api/chat``
        call. Spec default: 1500 s.
    """

    def __init__(self, base_url: str, per_call_timeout_s: float) -> None:
        self._base_url = base_url.rstrip("/")
        self._per_call_timeout_s = per_call_timeout_s

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Run one structured-output chat call; parse into ``response_schema``.

        Parameters
        ----------
        request_kwargs:
            Provider-neutral request (``model``, ``messages``, optional
            ``temperature``); see :meth:`_translate_request`.
        response_schema:
            Pydantic model the assistant's ``message.content`` must satisfy.

        Returns
        -------
        GenAIInvocationResult
            Parsed model instance, token usage, and the model name echoed by
            the server (falling back to the requested one).

        Raises
        ------
        IXException
            ``IX_002_000`` when the server cannot be reached, times out,
            answers with a non-2xx status, or returns a body that is not a
            JSON object. ``IX_002_001`` when ``message.content`` is not
            valid JSON for ``response_schema``.
        """
        body = self._translate_request(request_kwargs, response_schema)
        url = f"{self._base_url}/api/chat"

        try:
            async with httpx.AsyncClient(timeout=self._per_call_timeout_s) as http:
                resp = await http.post(url, json=body)
        except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
            # httpx normally wraps the builtin errors; they are listed anyway
            # so the mapping matches ``selfcheck``.
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama {exc.__class__.__name__}: {exc}",
            ) from exc

        # Anything outside 2xx (including unfollowed redirects) is a
        # provider failure; report the status instead of trying to parse it.
        if not 200 <= resp.status_code < 300:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=(
                    f"ollama HTTP {resp.status_code}: "
                    f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            )

        try:
            payload = resp.json()
        except ValueError as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"ollama non-JSON body: {resp.text[:_BODY_SNIPPET_MAX_CHARS]}",
            ) from exc
        if not isinstance(payload, dict):
            # Valid JSON but not the expected envelope (e.g. a list or a
            # bare string); ``.get`` below would otherwise raise.
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=(
                    "ollama non-object JSON body: "
                    f"{resp.text[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            )

        content = self._extract_content(payload)
        try:
            parsed = response_schema.model_validate_json(content)
        except ValidationError as exc:
            # Covers both schema violations and malformed/empty JSON.
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: {exc.__class__.__name__}: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc
        except ValueError as exc:
            # Defensive: any other ValueError raised while parsing is still
            # a structured-output failure.
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=(
                    f"{response_schema.__name__}: invalid JSON: "
                    f"body={content[:_BODY_SNIPPET_MAX_CHARS]}"
                ),
            ) from exc

        usage = GenAIUsage(
            prompt_tokens=self._token_count(payload.get("prompt_eval_count")),
            completion_tokens=self._token_count(payload.get("eval_count")),
        )
        model_name = str(payload.get("model") or request_kwargs.get("model") or "")
        return GenAIInvocationResult(parsed=parsed, usage=usage, model_name=model_name)

    async def selfcheck(
        self, expected_model: str
    ) -> Literal["ok", "degraded", "fail"]:
        """Probe ``/api/tags`` for ``/healthz``.

        ``ok`` when the server answers 200 and ``expected_model`` is listed
        (an untagged name also matches its ``:latest`` entry); ``degraded``
        when reachable but the model is missing; ``fail`` otherwise,
        including when the response is not the expected JSON shape. Never
        raises for a bad response. Spec §5, §11.
        """
        try:
            async with httpx.AsyncClient(timeout=_OLLAMA_TAGS_TIMEOUT_S) as http:
                resp = await http.get(f"{self._base_url}/api/tags")
        except (httpx.HTTPError, ConnectionError, TimeoutError):
            return "fail"

        if resp.status_code != 200:
            return "fail"

        try:
            payload = resp.json()
        except ValueError:
            return "fail"
        if not isinstance(payload, dict):
            return "fail"

        models = payload.get("models") or []
        if not isinstance(models, list):
            return "fail"

        listed = {
            str(entry.get("name", ""))
            for entry in models
            if isinstance(entry, dict)  # skip malformed entries instead of raising
        }
        if listed.isdisjoint(self._expected_names(expected_model)):
            return "degraded"
        return "ok"

    def _translate_request(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> dict[str, Any]:
        """Map provider-neutral kwargs to Ollama's /api/chat body.

        ``format`` carries the JSON schema of ``response_schema`` and
        ``stream`` is always ``False``. ``temperature`` is forwarded under
        ``options`` when it is set to a value; ``reasoning_effort`` is
        dropped.
        """
        body: dict[str, Any] = {
            "model": request_kwargs.get("model"),
            "messages": self._translate_messages(
                list(request_kwargs.get("messages") or [])
            ),
            "stream": False,
            "format": response_schema.model_json_schema(),
        }

        options: dict[str, Any] = {}
        temperature = request_kwargs.get("temperature")
        if temperature is not None:  # an explicit None means "not set"
            options["temperature"] = temperature
        # reasoning_effort intentionally dropped — Ollama doesn't support it.
        if options:
            body["options"] = options
        return body

    @staticmethod
    def _translate_messages(
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Collapse content-parts lists into single strings for Ollama.

        Only ``{"type": "text"}`` parts are kept, joined with newlines; a
        part whose ``text`` is ``None`` contributes an empty string.
        Non-list content is passed through untouched. Input messages are
        not mutated.
        """
        translated: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                content = "\n".join(
                    "" if part.get("text") is None else str(part["text"])
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                )
            translated.append({**msg, "content": content})
        return translated

    @staticmethod
    def _extract_content(payload: dict[str, Any]) -> str:
        """Return ``payload["message"]["content"]`` if it is a string, else ``""``."""
        message = payload.get("message")
        content = message.get("content") if isinstance(message, dict) else None
        return content if isinstance(content, str) else ""

    @staticmethod
    def _token_count(value: Any) -> int:
        """Coerce a token counter to ``int``; missing or unusable values count as 0."""
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _expected_names(expected_model: str) -> set[str]:
        """Return the ``/api/tags`` names that satisfy ``expected_model``.

        Ollama resolves a model name without a tag to ``<name>:latest``, so
        an untagged expectation matches either spelling. The tag is looked
        for in the last path segment only, so a ``host:port/`` registry
        prefix is not mistaken for one.
        """
        names = {expected_model}
        if expected_model and ":" not in expected_model.rsplit("/", 1)[-1]:
            names.add(f"{expected_model}:latest")
        return names
async def test_structured_output_round_trip() -> None:
    """Live round trip: a real Ollama server yields a parsed BankStatementHeader."""
    system_prompt = (
        "You extract bank statement header fields. "
        "Return valid JSON matching the given schema. "
        "Do not invent values."
    )
    statement_text = (
        "Bank: Deutsche Kreditbank (DKB)\n"
        "Currency: EUR\n"
        "IBAN: DE89370400440532013000\n"
        "Period: 2025-01-01 to 2025-01-31"
    )
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": statement_text},
    ]

    ollama = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=300.0)
    outcome = await ollama.invoke(
        request_kwargs={"model": _MODEL, "messages": chat_messages},
        response_schema=BankStatementHeader,
    )

    header = outcome.parsed
    assert isinstance(header, BankStatementHeader)
    assert isinstance(header.bank_name, str)
    assert header.bank_name  # non-empty
    assert isinstance(header.currency, str)
    assert outcome.model_name  # server echoes a model name


async def test_selfcheck_ok_against_real_server() -> None:
    """``selfcheck`` reports ``ok`` when the target model is pulled on the server."""
    ollama = OllamaClient(base_url=_OLLAMA_URL, per_call_timeout_s=5.0)
    status = await ollama.selfcheck(expected_model=_MODEL)
    assert status == "ok"
_BASE_URL = "http://ollama.test:11434"
_CHAT_URL = f"{_BASE_URL}/api/chat"
_TAGS_URL = f"{_BASE_URL}/api/tags"
_MODEL = "gpt-oss:20b"


class _Schema(BaseModel):
    """Trivial structured-output schema for the round-trip tests."""

    bank_name: str
    account_number: str | None = None


def _ollama_chat_ok_body(content_json: str) -> dict:
    """Build a minimal Ollama /api/chat success body."""
    return {
        "model": _MODEL,
        "message": {"role": "assistant", "content": content_json},
        "done": True,
        "eval_count": 42,
        "prompt_eval_count": 17,
    }


def _make_client(per_call_timeout_s: float = 5.0) -> OllamaClient:
    """Return a client pointed at the mocked Ollama base URL."""
    return OllamaClient(base_url=_BASE_URL, per_call_timeout_s=per_call_timeout_s)


def _minimal_request() -> dict:
    """Return the smallest valid request; a fresh dict per call."""
    return {
        "model": _MODEL,
        "messages": [{"role": "user", "content": "hi"}],
    }


def _sent_json(httpx_mock: HTTPXMock) -> dict:
    """Decode the JSON body of the single request captured by the mock."""
    import json  # local import keeps this helper self-contained

    requests = httpx_mock.get_requests()
    assert len(requests) == 1
    return json.loads(requests[0].read())


async def _invoke_expecting_error(client: OllamaClient) -> IXException:
    """Run a minimal ``invoke`` that must fail and return the raised error."""
    with pytest.raises(IXException) as ei:
        await client.invoke(
            request_kwargs=_minimal_request(),
            response_schema=_Schema,
        )
    return ei.value


class TestInvokeHappyPath:
    """Request shape and response parsing for successful calls."""

    async def test_posts_to_chat_endpoint_with_format_and_no_stream(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_CHAT_URL,
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"DKB","account_number":"DE89"}'),
        )

        result = await _make_client().invoke(
            request_kwargs={
                "model": _MODEL,
                "messages": [
                    {"role": "system", "content": "You extract."},
                    {"role": "user", "content": "Doc body"},
                ],
                "temperature": 0.2,
                "reasoning_effort": "high",  # dropped silently
            },
            response_schema=_Schema,
        )

        assert result.parsed == _Schema(bank_name="DKB", account_number="DE89")
        assert result.model_name == "gpt-oss:20b"
        assert result.usage.prompt_tokens == 17
        assert result.usage.completion_tokens == 42

        # Verify request shape.
        body_json = _sent_json(httpx_mock)
        assert body_json["model"] == "gpt-oss:20b"
        assert body_json["stream"] is False
        assert body_json["format"] == _Schema.model_json_schema()
        assert body_json["options"]["temperature"] == 0.2
        assert "reasoning_effort" not in body_json
        assert body_json["messages"] == [
            {"role": "system", "content": "You extract."},
            {"role": "user", "content": "Doc body"},
        ]

    async def test_text_parts_content_list_is_joined(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_CHAT_URL,
            method="POST",
            json=_ollama_chat_ok_body('{"bank_name":"X"}'),
        )

        await _make_client().invoke(
            request_kwargs={
                "model": _MODEL,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "part-a"},
                            {"type": "text", "text": "part-b"},
                        ],
                    }
                ],
            },
            response_schema=_Schema,
        )

        assert _sent_json(httpx_mock)["messages"] == [
            {"role": "user", "content": "part-a\npart-b"}
        ]


class TestInvokeErrorPaths:
    """Transport, HTTP and structured-output failures map to IX error codes."""

    async def test_connection_error_maps_to_002_000(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        error = await _invoke_expecting_error(_make_client(1.0))
        assert error.code is IXErrorCode.IX_002_000

    async def test_read_timeout_maps_to_002_000(self, httpx_mock: HTTPXMock) -> None:
        httpx_mock.add_exception(httpx.ReadTimeout("slow"))
        error = await _invoke_expecting_error(_make_client(0.5))
        assert error.code is IXErrorCode.IX_002_000

    async def test_500_maps_to_002_000_with_body_snippet(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_CHAT_URL,
            method="POST",
            status_code=500,
            text="boom boom server broken",
        )
        error = await _invoke_expecting_error(_make_client())
        assert error.code is IXErrorCode.IX_002_000
        assert "boom" in (error.detail or "")

    async def test_200_with_invalid_json_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_CHAT_URL,
            method="POST",
            json=_ollama_chat_ok_body("not-json"),
        )
        error = await _invoke_expecting_error(_make_client())
        assert error.code is IXErrorCode.IX_002_001

    async def test_200_with_schema_violation_maps_to_002_001(
        self, httpx_mock: HTTPXMock
    ) -> None:
        # Missing required `bank_name` field.
        httpx_mock.add_response(
            url=_CHAT_URL,
            method="POST",
            json=_ollama_chat_ok_body('{"account_number":"DE89"}'),
        )
        error = await _invoke_expecting_error(_make_client())
        assert error.code is IXErrorCode.IX_002_001


class TestSelfcheck:
    """``selfcheck`` status mapping for the ``/api/tags`` probe."""

    async def test_selfcheck_ok_when_model_listed(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_TAGS_URL,
            method="GET",
            json={"models": [{"name": "gpt-oss:20b"}, {"name": "qwen2.5:32b"}]},
        )
        assert await _make_client().selfcheck(expected_model=_MODEL) == "ok"

    async def test_selfcheck_degraded_when_model_missing(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_response(
            url=_TAGS_URL,
            method="GET",
            json={"models": [{"name": "qwen2.5:32b"}]},
        )
        assert await _make_client().selfcheck(expected_model=_MODEL) == "degraded"

    async def test_selfcheck_fail_on_connection_error(
        self, httpx_mock: HTTPXMock
    ) -> None:
        httpx_mock.add_exception(httpx.ConnectError("refused"))
        assert await _make_client().selfcheck(expected_model=_MODEL) == "fail"