fix(genai): schema in prompt (#40)
Some checks failed
tests / test (push) Has been cancelled
Some checks failed
tests / test (push) Has been cancelled
This commit is contained in:
commit
763407ba1c
2 changed files with 52 additions and 12 deletions
|
|
@ -162,22 +162,26 @@ class OllamaClient:
|
||||||
"""Map provider-neutral kwargs to Ollama's /api/chat body.
|
"""Map provider-neutral kwargs to Ollama's /api/chat body.
|
||||||
|
|
||||||
Schema strategy for Ollama 0.11.8: we pass ``format="json"`` (loose
|
Schema strategy for Ollama 0.11.8: we pass ``format="json"`` (loose
|
||||||
JSON mode) rather than the full Pydantic schema. The llama.cpp
|
JSON mode) and bake the Pydantic schema into a system message
|
||||||
structured-output implementation in 0.11.8 segfaults on schemas
|
ahead of the caller's own system prompt. Rationale:
|
||||||
involving ``anyOf``, ``$ref``, or ``pattern`` — which Pydantic v2
|
|
||||||
emits for Optional / nested-model / Decimal fields.
|
|
||||||
|
|
||||||
In loose JSON mode Ollama guarantees only syntactically-valid
|
* The full Pydantic schema as ``format=<schema>`` crashes llama.cpp's
|
||||||
JSON; we enforce the schema on our side by catching the Pydantic
|
structured-output implementation (SIGSEGV) on every non-trivial
|
||||||
``ValidationError`` at parse time and raising IX_002_001. The
|
shape — ``anyOf`` / ``$ref`` / ``pattern`` all trigger it.
|
||||||
system prompt (built upstream in GenAIStep) already tells the
|
* ``format="json"`` alone guarantees valid JSON but not the shape;
|
||||||
model what JSON shape to emit, so loose mode is the right
|
models routinely return ``{}`` when not told what fields to emit.
|
||||||
abstraction layer here.
|
* Injecting the schema into the prompt is the cheapest way to
|
||||||
|
get both: the model sees the expected shape explicitly, Pydantic
|
||||||
|
validates the response at parse time (IX_002_001 on mismatch).
|
||||||
|
|
||||||
|
Non-Ollama ``GenAIClient`` impls can ignore this behaviour and use
|
||||||
|
native structured-output (``response_format`` on OpenAI, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = self._translate_messages(
|
messages = self._translate_messages(
|
||||||
list(request_kwargs.get("messages") or [])
|
list(request_kwargs.get("messages") or [])
|
||||||
)
|
)
|
||||||
|
messages = _inject_schema_system_message(messages, response_schema)
|
||||||
body: dict[str, Any] = {
|
body: dict[str, Any] = {
|
||||||
"model": request_kwargs.get("model"),
|
"model": request_kwargs.get("model"),
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
|
@ -214,6 +218,34 @@ class OllamaClient:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_schema_system_message(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
response_schema: type[BaseModel],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Prepend a system message that pins the expected JSON shape.
|
||||||
|
|
||||||
|
Ollama's ``format="json"`` mode guarantees valid JSON but not the
|
||||||
|
field set or names. We emit the Pydantic schema as JSON and
|
||||||
|
instruct the model to match it. If the caller already provides a
|
||||||
|
system message, we prepend ours; otherwise ours becomes the first
|
||||||
|
system turn.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
schema_json = _json.dumps(
|
||||||
|
_sanitise_schema_for_ollama(response_schema.model_json_schema()),
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
guidance = (
|
||||||
|
"Respond ONLY with a single JSON object matching this JSON Schema "
|
||||||
|
"exactly. No prose, no code fences, no explanations. All top-level "
|
||||||
|
"properties listed in `required` MUST be present. Use null for "
|
||||||
|
"fields you cannot confidently extract. The JSON Schema:\n"
|
||||||
|
f"{schema_json}"
|
||||||
|
)
|
||||||
|
return [{"role": "system", "content": guidance}, *messages]
|
||||||
|
|
||||||
|
|
||||||
def _sanitise_schema_for_ollama(schema: Any) -> Any:
|
def _sanitise_schema_for_ollama(schema: Any) -> Any:
|
||||||
"""Strip null branches from ``anyOf`` unions.
|
"""Strip null branches from ``anyOf`` unions.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,12 @@ class TestInvokeHappyPath:
|
||||||
assert body_json["format"] == "json"
|
assert body_json["format"] == "json"
|
||||||
assert body_json["options"]["temperature"] == 0.2
|
assert body_json["options"]["temperature"] == 0.2
|
||||||
assert "reasoning_effort" not in body_json
|
assert "reasoning_effort" not in body_json
|
||||||
assert body_json["messages"] == [
|
# A schema-guidance system message is prepended to the caller's
|
||||||
|
# messages so Ollama (format=json loose mode) emits the right shape.
|
||||||
|
msgs = body_json["messages"]
|
||||||
|
assert msgs[0]["role"] == "system"
|
||||||
|
assert "JSON Schema" in msgs[0]["content"]
|
||||||
|
assert msgs[1:] == [
|
||||||
{"role": "system", "content": "You extract."},
|
{"role": "system", "content": "You extract."},
|
||||||
{"role": "user", "content": "Doc body"},
|
{"role": "user", "content": "Doc body"},
|
||||||
]
|
]
|
||||||
|
|
@ -119,7 +124,10 @@ class TestInvokeHappyPath:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
request_body = json.loads(httpx_mock.get_requests()[0].read())
|
request_body = json.loads(httpx_mock.get_requests()[0].read())
|
||||||
assert request_body["messages"] == [
|
# First message is the auto-injected schema guidance; after that
|
||||||
|
# the caller's user message has its text parts joined.
|
||||||
|
assert request_body["messages"][0]["role"] == "system"
|
||||||
|
assert request_body["messages"][1:] == [
|
||||||
{"role": "user", "content": "part-a\npart-b"}
|
{"role": "user", "content": "part-a\npart-b"}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue