"""Tests for the Pipeline orchestrator + Step ABC + Timer (spec §3, §4)."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import pytest

from ix.contracts import Context, RequestIX, ResponseIX
from ix.errors import IXErrorCode, IXException
from ix.pipeline import Pipeline, Step


def _make_request() -> RequestIX:
    """Build a minimal RequestIX with one text item and no files."""
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id="test",
        request_id="r-1",
        context=Context(files=[], texts=["hello"]),
    )


class StubStep(Step):
    """Hand-written Step double. Records call order on a shared list.

    Args:
        name: Value reported by ``step_name`` and used as the prefix of every
            entry appended to ``calls``.
        calls: List that receives ``"<name>.validate"`` / ``"<name>.process"``
            entries. Pass the same list to several stubs to observe the
            interleaving across steps; when omitted, a private list is used.
        validate_returns: Value ``validate`` returns when it does not raise.
        validate_raises: If set, ``validate`` raises it after recording the call.
        process_raises: If set, ``process`` raises it after recording the call.
            Any exception type is accepted so tests can exercise non-IX errors.
        mutate: Optional callback applied to ``response_ix`` inside ``process``
            (only when ``process_raises`` is unset); its return value is ignored.
    """

    def __init__(
        self,
        name: str,
        *,
        calls: list[str] | None = None,
        validate_returns: bool = True,
        validate_raises: IXException | None = None,
        process_raises: BaseException | None = None,
        mutate: Callable[[ResponseIX], Any] | None = None,
    ) -> None:
        self.name = name
        self._calls = calls if calls is not None else []
        self._validate_returns = validate_returns
        self._validate_raises = validate_raises
        self._process_raises = process_raises
        self._mutate = mutate

    # NOTE(review): the ignore suggests Step declares step_name differently
    # (e.g. as a plain attribute) — confirm against ix.pipeline.Step.
    @property
    def step_name(self) -> str:  # type: ignore[override]
        return self.name

    async def validate(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> bool:
        """Record the call, then raise or return the configured result."""
        self._calls.append(f"{self.name}.validate")
        if self._validate_raises is not None:
            raise self._validate_raises
        return self._validate_returns

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Record the call, optionally raise or mutate, and return response_ix."""
        self._calls.append(f"{self.name}.process")
        if self._process_raises is not None:
            raise self._process_raises
        if self._mutate is not None:
            self._mutate(response_ix)
        return response_ix


class TestOrdering:
    """Each step runs validate then process, in registration order."""

    async def test_steps_run_in_registered_order(self) -> None:
        calls: list[str] = []
        pipeline = Pipeline(
            steps=[
                StubStep("a", calls=calls),
                StubStep("b", calls=calls),
                StubStep("c", calls=calls),
            ]
        )
        await pipeline.start(_make_request())
        assert calls == [
            "a.validate",
            "a.process",
            "b.validate",
            "b.process",
            "c.validate",
            "c.process",
        ]


class TestSkipOnFalse:
    """validate() returning False skips that step's process() without error."""

    async def test_validate_false_skips_process(self) -> None:
        calls: list[str] = []
        pipeline = Pipeline(
            steps=[
                StubStep("a", calls=calls, validate_returns=False),
                StubStep("b", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        assert calls == ["a.validate", "b.validate", "b.process"]
        assert response.error is None


class TestErrorFromValidate:
    """An IXException from validate() is recorded on the response and halts the run."""

    async def test_validate_raising_ix_sets_error_and_aborts(self) -> None:
        calls: list[str] = []
        err = IXException(IXErrorCode.IX_000_002, detail="nothing")
        pipeline = Pipeline(
            steps=[
                StubStep("a", calls=calls, validate_raises=err),
                StubStep("b", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        assert calls == ["a.validate"]
        assert response.error is not None
        assert response.error.startswith("IX_000_002")
        assert "nothing" in response.error


class TestErrorFromProcess:
    """An IXException from process() is recorded on the response and halts the run."""

    async def test_process_raising_ix_sets_error_and_aborts(self) -> None:
        calls: list[str] = []
        err = IXException(IXErrorCode.IX_002_001, detail="bad parse")
        pipeline = Pipeline(
            steps=[
                StubStep("a", calls=calls, process_raises=err),
                StubStep("b", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        assert calls == ["a.validate", "a.process"]
        assert response.error is not None
        assert response.error.startswith("IX_002_001")
        assert "bad parse" in response.error


class TestTimings:
    """Every step the pipeline reaches gets an entry in response.metadata.timings."""

    async def test_timings_populated_for_every_executed_step(self) -> None:
        calls: list[str] = []
        pipeline = Pipeline(
            steps=[
                StubStep("alpha", calls=calls),
                StubStep("beta", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        # One timing per executed step.
        names = [t["step"] for t in response.metadata.timings]
        assert names == ["alpha", "beta"]
        for entry in response.metadata.timings:
            assert isinstance(entry["elapsed_seconds"], float)
            assert entry["elapsed_seconds"] >= 0.0

    async def test_timing_recorded_even_when_validate_false(self) -> None:
        calls: list[str] = []
        pipeline = Pipeline(
            steps=[
                StubStep("skipper", calls=calls, validate_returns=False),
                StubStep("runner", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        names = [t["step"] for t in response.metadata.timings]
        assert names == ["skipper", "runner"]

    async def test_timing_recorded_when_validate_raises(self) -> None:
        calls: list[str] = []
        err = IXException(IXErrorCode.IX_001_001, detail="missing")
        pipeline = Pipeline(
            steps=[
                StubStep("boom", calls=calls, validate_raises=err),
                StubStep("unreached", calls=calls),
            ]
        )
        response = await pipeline.start(_make_request())
        # The failing step is timed; the step after it never runs, so it is not.
        names = [t["step"] for t in response.metadata.timings]
        assert names == ["boom"]


class TestNonIXExceptionStillSurfaces:
    """Non-IX exceptions propagate out of Pipeline.start and stop later steps."""

    async def test_generic_exception_in_process_aborts(self) -> None:
        calls: list[str] = []
        pipeline = Pipeline(
            steps=[
                StubStep("a", calls=calls, process_raises=RuntimeError("kaboom")),
                StubStep("b", calls=calls),
            ]
        )
        with pytest.raises(RuntimeError):
            await pipeline.start(_make_request())
        assert calls == ["a.validate", "a.process"]


class TestStepMutation:
    """All steps operate on the same ResponseIX instance that start() returns."""

    async def test_response_ix_shared_across_steps(self) -> None:
        def mutate_a(resp: ResponseIX) -> None:
            resp.use_case = "set-by-a"

        def mutate_b(resp: ResponseIX) -> None:
            # b sees what a wrote.
            assert resp.use_case == "set-by-a"
            resp.use_case_name = "set-by-b"

        pipeline = Pipeline(
            steps=[
                StubStep("a", mutate=mutate_a),
                StubStep("b", mutate=mutate_b),
            ]
        )
        response = await pipeline.start(_make_request())
        assert response.use_case == "set-by-a"
        assert response.use_case_name == "set-by-b"