diff --git a/pyproject.toml b/pyproject.toml index 986ad2b..1c766e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest>=7.0", + "pytest-cov>=4.0", "ruff>=0.8", ] diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bcd844a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,85 @@ +"""Shared test fixtures for ChainWeaver.""" + +from __future__ import annotations + +import pytest +from helpers import ( + FormattedOutput, + NumberInput, + ValueInput, + ValueOutput, + _add_ten_fn, + _double_fn, + _format_fn, +) + +from chainweaver.executor import FlowExecutor +from chainweaver.flow import Flow, FlowStep +from chainweaver.registry import FlowRegistry +from chainweaver.tools import Tool + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def double_tool() -> Tool: + return Tool( + name="double", + description="Doubles a number.", + input_schema=NumberInput, + output_schema=ValueOutput, + fn=_double_fn, + ) + + +@pytest.fixture() +def add_ten_tool() -> Tool: + return Tool( + name="add_ten", + description="Adds 10 to a value.", + input_schema=ValueInput, + output_schema=ValueOutput, + fn=_add_ten_fn, + ) + + +@pytest.fixture() +def format_tool() -> Tool: + return Tool( + name="format_result", + description="Formats a value.", + input_schema=ValueInput, + output_schema=FormattedOutput, + fn=_format_fn, + ) + + +@pytest.fixture() +def linear_flow() -> Flow: + return Flow( + name="double_add_format", + description="Doubles a number, adds 10, and formats the result.", + steps=[ + FlowStep(tool_name="double", input_mapping={"number": "number"}), + FlowStep(tool_name="add_ten", input_mapping={"value": "value"}), + FlowStep(tool_name="format_result", input_mapping={"value": "value"}), + ], + ) + + +@pytest.fixture() +def executor( + linear_flow: Flow, + double_tool: Tool, + add_ten_tool: Tool, + format_tool: Tool, +) -> FlowExecutor: + registry = FlowRegistry() + registry.register_flow(linear_flow) + ex = FlowExecutor(registry=registry) + ex.register_tool(double_tool) + ex.register_tool(add_ten_tool) + ex.register_tool(format_tool) + return ex diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..7186026 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,42 @@ +"""Shared Pydantic schemas and helper functions for ChainWeaver tests.""" + +from __future__ import annotations + +from pydantic import BaseModel + +# --------------------------------------------------------------------------- +# Shared Pydantic schemas +# --------------------------------------------------------------------------- + + +class NumberInput(BaseModel): + number: int + + +class ValueOutput(BaseModel): + value: int + + +class ValueInput(BaseModel): + value: int + + +class FormattedOutput(BaseModel): + result: str + + +# --------------------------------------------------------------------------- +# Shared tool functions +# --------------------------------------------------------------------------- + + +def _double_fn(inp: NumberInput) -> dict: + return {"value": inp.number * 2} + + +def _add_ten_fn(inp: ValueInput) -> dict: + return {"value": inp.value + 10} + + +def _format_fn(inp: ValueInput) -> dict: + return {"result": f"Final value: {inp.value}"} diff --git a/tests/test_flow_execution.py b/tests/test_flow_execution.py index 31f1652..f3e4c26 100644 --- a/tests/test_flow_execution.py +++ b/tests/test_flow_execution.py @@ -3,6 +3,12 @@ from __future__ import annotations import pytest +from helpers import ( + FormattedOutput, + NumberInput, + ValueOutput, + _double_fn, +) from pydantic import BaseModel, ValidationError from chainweaver.exceptions import ( @@ -17,101 +23,6 @@ from chainweaver.registry import FlowRegistry from chainweaver.tools import Tool -# --------------------------------------------------------------------------- -# Shared fixtures -# --------------------------------------------------------------------------- - - -class NumberInput(BaseModel): - number: int - - -class ValueOutput(BaseModel): - value: int - - -class ValueInput(BaseModel): - value: int - - -class FormattedOutput(BaseModel): - result: str - - -def _double_fn(inp: NumberInput) -> dict: - return {"value": inp.number * 2} - - -def _add_ten_fn(inp: ValueInput) -> dict: - return {"value": inp.value + 10} - - -def _format_fn(inp: ValueInput) -> dict: - return {"result": f"Final value: {inp.value}"} - - -@pytest.fixture() -def double_tool() -> Tool: - return Tool( - name="double", - description="Doubles a number.", - input_schema=NumberInput, - output_schema=ValueOutput, - fn=_double_fn, - ) - - -@pytest.fixture() -def add_ten_tool() -> Tool: - return Tool( - name="add_ten", - description="Adds 10 to a value.", - input_schema=ValueInput, - output_schema=ValueOutput, - fn=_add_ten_fn, - ) - - -@pytest.fixture() -def format_tool() -> Tool: - return Tool( - name="format_result", - description="Formats a value.", - input_schema=ValueInput, - output_schema=FormattedOutput, - fn=_format_fn, - ) - - -@pytest.fixture() -def linear_flow() -> Flow: - return Flow( - name="double_add_format", - description="Doubles a number, adds 10, and formats the result.", - steps=[ - FlowStep(tool_name="double", input_mapping={"number": "number"}), - FlowStep(tool_name="add_ten", input_mapping={"value": "value"}), - FlowStep(tool_name="format_result", input_mapping={"value": "value"}), - ], - ) - - -@pytest.fixture() -def executor( - linear_flow: Flow, - double_tool: Tool, - add_ten_tool: Tool, - format_tool: Tool, -) -> FlowExecutor: - registry = FlowRegistry() - registry.register_flow(linear_flow) - ex = FlowExecutor(registry=registry) - ex.register_tool(double_tool) - ex.register_tool(add_ten_tool) - ex.register_tool(format_tool) - return ex - - # --------------------------------------------------------------------------- # Successful execution # --------------------------------------------------------------------------- @@ -173,7 +84,7 @@ def test_tool_not_found_fails_step(self, linear_flow: Flow) -> None: registry = FlowRegistry() registry.register_flow(linear_flow) ex = FlowExecutor(registry=registry) - # No tools registered — step 0 should fail gracefully. + # No tools registered \u2014 step 0 should fail gracefully. result = ex.execute_flow("double_add_format", {"number": 5}) assert result.success is False assert len(result.execution_log) == 1 @@ -367,7 +278,7 @@ def sum_fn(inp: CtxInput) -> dict: class TestFlowExecutionError: - """Tool fn raises a generic exception → wrapped as FlowExecutionError.""" + """Tool fn raises a generic exception \u2192 wrapped as FlowExecutionError.""" def test_runtime_error_wrapped(self) -> None: class InSchema(BaseModel): @@ -659,3 +570,144 @@ class StrictOutput(BaseModel): assert len(result.execution_log) == 1 assert result.execution_log[0].step_index == 0 # len(steps) == 0 assert isinstance(result.execution_log[0].error, SchemaValidationError) + + +# --------------------------------------------------------------------------- +# Single-step flow +# --------------------------------------------------------------------------- + + +class TestSingleStepFlow: + """A flow with exactly one step \u2014 simplest chaining case.""" + + def test_single_step_succeeds( + self, + double_tool: Tool, + ) -> None: + flow = Flow( + name="single_step", + description="One-step flow that doubles a number.", + steps=[ + FlowStep(tool_name="double", input_mapping={"number": "number"}), + ], + ) + registry = FlowRegistry() + registry.register_flow(flow) + ex = FlowExecutor(registry=registry) + ex.register_tool(double_tool) + + result = ex.execute_flow("single_step", {"number": 7}) + assert result.success is True + assert result.final_output is not None + assert result.final_output["value"] == 14 + assert len(result.execution_log) == 1 + assert result.execution_log[0].tool_name == "double" + + +# --------------------------------------------------------------------------- +# Context accumulation +# --------------------------------------------------------------------------- + + +class TestContextAccumulation: + """Verify that outputs from *all* steps are merged into final_output.""" + + def test_context_accumulates_all_outputs( + self, + executor: FlowExecutor, + ) -> None: + result = executor.execute_flow("double_add_format", {"number": 5}) + assert result.success is True + assert result.final_output is not None + # Initial input key is preserved. + assert "number" in result.final_output + assert result.final_output["number"] == 5 + # Intermediate key: both double and add_ten write "value"; + # 20 (from add_ten) confirms last-write-wins merge semantics. + assert "value" in result.final_output + assert result.final_output["value"] == 20 + # Final key from format_result step. + assert "result" in result.final_output + assert result.final_output["result"] == "Final value: 20" + + +# --------------------------------------------------------------------------- +# Tool runtime exception: ZeroDivisionError +# --------------------------------------------------------------------------- + + +class TestToolZeroDivisionError: + """A ZeroDivisionError inside a tool fn is wrapped as FlowExecutionError.""" + + def test_zero_division_error_wrapped(self) -> None: + class DivInput(BaseModel): + numerator: int + denominator: int + + class DivOutput(BaseModel): + result: int + + def divide_fn(inp: DivInput) -> dict: + return {"result": inp.numerator // inp.denominator} + + tool = Tool( + name="divide", + description="Integer division.", + input_schema=DivInput, + output_schema=DivOutput, + fn=divide_fn, + ) + flow = Flow( + name="divide_flow", + description="Flow that divides.", + steps=[ + FlowStep( + tool_name="divide", + input_mapping={ + "numerator": "numerator", + "denominator": "denominator", + }, + ) + ], + ) + registry = FlowRegistry() + registry.register_flow(flow) + ex = FlowExecutor(registry=registry) + ex.register_tool(tool) + + result = ex.execute_flow("divide_flow", {"numerator": 10, "denominator": 0}) + assert result.success is False + record = result.execution_log[0] + assert record.success is False + assert isinstance(record.error, FlowExecutionError) + assert "integer division or modulo by zero" in str(record.error) + + +# --------------------------------------------------------------------------- +# Boundary values: negative numbers and zero +# --------------------------------------------------------------------------- + + +class TestBoundaryValues: + """Negative numbers and zero through the double\u2192add\u2192format chain.""" + + def test_negative_input(self, executor: FlowExecutor) -> None: + result = executor.execute_flow("double_add_format", {"number": -3}) + # double(-3) \u2192 -6, add_ten(-6) \u2192 4, format(4) \u2192 "Final value: 4" + assert result.success is True + assert result.final_output is not None + assert result.final_output["result"] == "Final value: 4" + + def test_large_positive_input(self, executor: FlowExecutor) -> None: + result = executor.execute_flow("double_add_format", {"number": 1000}) + # double(1000)\u21922000, add_ten(2000)\u21922010, format\u2192"Final value: 2010" + assert result.success is True + assert result.final_output is not None + assert result.final_output["result"] == "Final value: 2010" + + def test_large_negative_input(self, executor: FlowExecutor) -> None: + result = executor.execute_flow("double_add_format", {"number": -1000}) + # double(-1000) \u2192 -2000, add_ten(-2000) \u2192 -1990 + assert result.success is True + assert result.final_output is not None + assert result.final_output["result"] == "Final value: -1990" diff --git a/tests/test_registry.py b/tests/test_registry.py index 868255a..a8d36ac 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -123,3 +123,27 @@ def test_match_is_case_insensitive(self) -> None: registry.register_flow(flow) match = registry.match_flow_by_intent("uppercase") assert match is not None + + def test_empty_registry_returns_none(self) -> None: + """An empty registry has nothing to match.""" + registry = FlowRegistry() + assert registry.match_flow_by_intent("anything") is None + + +# --------------------------------------------------------------------------- +# Overwrite preserves count +# --------------------------------------------------------------------------- + + +class TestOverwritePreservesCount: + def test_register_flow_then_overwrite_preserves_count(self) -> None: + registry = FlowRegistry() + registry.register_flow(_make_flow("keep")) + registry.register_flow(_make_flow("replace_me")) + assert len(registry) == 2 + + new_flow = _make_flow("replace_me") + new_flow.description = "Replaced" + registry.register_flow(new_flow, overwrite=True) + assert len(registry) == 2 + assert registry.get_flow("replace_me").description == "Replaced"