Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions src/distillery/mcp/tools/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,45 @@ def success_response(data: dict[str, Any]) -> list[types.TextContent]:


def validate_required(arguments: dict[str, Any], *fields: str) -> str | None:
"""Return an error message if any required field is missing from *arguments*.
"""Return an error message if any required field is absent, empty, or blank.

A field is considered missing if it is absent, ``None``, or (for strings)
empty. Values such as ``0`` and ``False`` are **not** treated as missing.
Distinguishes two failure modes so that agents parsing the error message
can recover without retrying with the same payload:

* **Missing** — field is absent from ``arguments`` or set to ``None``.
Reported as ``"Missing required fields: ..."``.
* **Empty** — field is present as a string that is empty or whitespace-only.
Reported as ``"Field '...' must be a non-empty string"`` (or the plural
form for multiple fields).

Non-string falsy values (``0``, ``False``, ``[]``, ``{}``) are **not**
treated as missing or empty — those are valid inputs for fields that
accept them.

When both categories fail in the same call, missing is reported first.

Args:
arguments: The tool argument dict.
*fields: Field names that must be present and non-empty.

Returns:
An error message string if validation fails, or ``None`` if all fields
are present.
are present and non-empty.
"""
missing = [
f
for f in fields
if arguments.get(f) is None or (isinstance(arguments.get(f), str) and not arguments.get(f))
]
missing: list[str] = []
empty: list[str] = []
for f in fields:
value = arguments.get(f)
if value is None:
missing.append(f)
elif isinstance(value, str) and not value.strip():
empty.append(f)
if missing:
return f"Missing required fields: {', '.join(missing)}"
if empty:
if len(empty) == 1:
return f"Field {empty[0]!r} must be a non-empty string"
return f"Fields must be non-empty strings: {', '.join(repr(f) for f in empty)}"
return None


Expand Down
7 changes: 4 additions & 3 deletions src/distillery/mcp/tools/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,13 @@ async def _handle_aggregate(
MCP content list with a JSON payload containing ``group_by``,
``groups``, ``total_entries``, and ``total_groups``.
"""
group_by = arguments.get("group_by", "")
err_group_by = validate_type(arguments, "group_by", str, "string")
if err_group_by:
return error_response("INVALID_PARAMS", err_group_by)
if not group_by:
return error_response("INVALID_PARAMS", "Missing required field: group_by")
err = validate_required(arguments, "group_by")
if err:
return error_response("INVALID_PARAMS", err)
group_by: str = arguments["group_by"]
if group_by not in _AGGREGATE_GROUP_BY_MAP:
return error_response(
"INVALID_PARAMS",
Expand Down
37 changes: 37 additions & 0 deletions tests/test_mcp_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,26 @@ async def test_classify_validates_confidence_out_of_range(
assert data["error"] is True
assert data["code"] == "INVALID_PARAMS"

async def test_classify_reports_invalid_entry_type_before_id_lookup(
    self, store: DuckDBStore, config: DistilleryConfig
) -> None:
    """Issue #372: when both entry_id and entry_type are bad, schema
    validation must fire before the DB lookup so the agent sees the full
    set of errors and does not burn a retry fixing one at a time."""
    response = await _handle_classify(
        store,
        config,
        {
            "entry_id": "00000000-0000-0000-0000-000000000000",
            "entry_type": "invalid_type_xyz",
            "confidence": 0.5,
        },
    )
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    # The error must point at the invalid field, not the phantom id.
    assert data["details"]["field"] == "entry_type"


# ---------------------------------------------------------------------------
# distillery_review_queue tests
Expand Down Expand Up @@ -572,6 +592,23 @@ async def test_resolve_validates_invalid_action(self, store: DuckDBStore) -> Non
assert data["error"] is True
assert data["code"] == "INVALID_PARAMS"

async def test_resolve_reports_invalid_action_before_id_lookup(
    self, store: DuckDBStore
) -> None:
    """Issue #372: bad id + bad action should surface the action error
    rather than short-circuiting on NOT_FOUND for the phantom id."""
    response = await _handle_resolve_review(
        store,
        {
            "entry_id": "00000000-0000-0000-0000-000000000000",
            "action": "nuke_from_orbit",
        },
    )
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    # The offending value is echoed back so the caller can see what was rejected.
    assert "nuke_from_orbit" in data["message"]

async def test_resolve_approve_without_reviewer(self, store: DuckDBStore) -> None:
"""Reviewer field is optional."""
entry = make_entry(status=EntryStatus.PENDING_REVIEW)
Expand Down
159 changes: 159 additions & 0 deletions tests/test_mcp_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Unit tests for distillery.mcp.tools._common validation helpers.

Covers ``validate_required``, ``validate_type``, ``validate_enum``, and
``validate_positive_int``. These helpers shape every INVALID_PARAMS error
returned by the MCP tools, so their contract is asserted here explicitly —
agents rely on the error message to recover without replaying identical
payloads (issue #371).
"""

from __future__ import annotations

import pytest

from distillery.mcp.tools._common import (
validate_enum,
validate_positive_int,
validate_required,
validate_type,
)

pytestmark = pytest.mark.unit


class TestValidateRequired:
    """Tests for ``validate_required`` — field presence and non-emptiness."""

    def test_all_fields_present_returns_none(self) -> None:
        assert validate_required({"a": "x", "b": "y"}, "a", "b") is None

    def test_single_absent_field(self) -> None:
        msg = validate_required({"a": "x"}, "a", "b")
        assert msg == "Missing required fields: b"

    def test_multiple_absent_fields_listed_in_order(self) -> None:
        msg = validate_required({}, "a", "b", "c")
        assert msg == "Missing required fields: a, b, c"

    def test_explicit_none_treated_as_missing(self) -> None:
        msg = validate_required({"a": None}, "a")
        assert msg == "Missing required fields: a"

    def test_empty_string_reports_non_empty_message(self) -> None:
        """Issue #371: empty string is present, not missing."""
        msg = validate_required({"query": ""}, "query")
        assert msg == "Field 'query' must be a non-empty string"

    def test_whitespace_only_string_treated_as_empty(self) -> None:
        """Issue #371: whitespace-only strings are functionally empty."""
        msg = validate_required({"query": " \t\n"}, "query")
        assert msg == "Field 'query' must be a non-empty string"

    def test_multiple_empty_fields(self) -> None:
        msg = validate_required({"a": "", "b": ""}, "a", "b")
        assert msg == "Fields must be non-empty strings: 'a', 'b'"

    def test_missing_takes_precedence_over_empty(self) -> None:
        """When both categories fail, report missing first so clients fix the
        more fundamental issue before retrying."""
        msg = validate_required({"a": ""}, "a", "b")
        assert msg == "Missing required fields: b"

    def test_zero_is_not_treated_as_missing(self) -> None:
        """Guards the #245 hint in issue #371: ``0`` is a valid int payload."""
        assert validate_required({"n": 0}, "n") is None

    def test_false_is_not_treated_as_missing(self) -> None:
        assert validate_required({"flag": False}, "flag") is None

    def test_empty_list_is_not_treated_as_missing(self) -> None:
        """Non-string falsy values are valid — only strings get the empty check."""
        assert validate_required({"tags": []}, "tags") is None

    def test_empty_dict_is_not_treated_as_missing(self) -> None:
        assert validate_required({"meta": {}}, "meta") is None

    def test_non_string_non_none_value_passes(self) -> None:
        assert validate_required({"n": 42, "flag": True}, "n", "flag") is None


class TestValidateType:
    """Tests for ``validate_type``."""

    def test_correct_type_returns_none(self) -> None:
        assert validate_type({"x": [1, 2]}, "x", list, "list") is None

    def test_wrong_type_returns_message(self) -> None:
        msg = validate_type({"x": "not-a-list"}, "x", list, "list of strings")
        assert msg == "Field 'x' must be a list of strings"

    def test_absent_field_is_ok(self) -> None:
        """validate_type only runs when the field is present — presence is the
        job of validate_required."""
        assert validate_type({}, "x", list, "list") is None

    def test_none_value_is_ok(self) -> None:
        assert validate_type({"x": None}, "x", list, "list") is None

    def test_tuple_of_types_accepted(self) -> None:
        assert validate_type({"n": 1.5}, "n", (int, float), "number") is None
        assert validate_type({"n": 1}, "n", (int, float), "number") is None


class TestValidateEnum:
    """Tests for ``validate_enum``."""

    def test_valid_value_returns_none(self) -> None:
        assert validate_enum({"action": "approve"}, "action", {"approve", "archive"}) is None

    def test_invalid_value_returns_message(self) -> None:
        msg = validate_enum({"action": "nuke"}, "action", {"approve", "archive"})
        assert msg is not None
        assert "nuke" in msg
        assert "approve" in msg and "archive" in msg

    def test_absent_field_is_ok(self) -> None:
        assert validate_enum({}, "action", {"approve"}) is None

    def test_non_string_value_rejected(self) -> None:
        """JSON arrays/objects arrive as list/dict — must not raise TypeError."""
        msg = validate_enum({"action": ["approve"]}, "action", {"approve"})
        assert msg is not None
        assert "approve" in msg

    def test_label_overrides_field_name_in_message(self) -> None:
        msg = validate_enum({"a": "bad"}, "a", {"good"}, label="choice")
        assert msg is not None
        assert "Invalid choice" in msg


class TestValidatePositiveInt:
    """Tests for ``validate_positive_int``."""

    def test_valid_int_returned(self) -> None:
        assert validate_positive_int({"limit": 10}, "limit") == 10

    def test_default_used_when_absent(self) -> None:
        assert validate_positive_int({}, "limit", default=5) == 5

    def test_missing_without_default_returns_error(self) -> None:
        result = validate_positive_int({}, "limit")
        assert isinstance(result, tuple)
        assert "required" in result[1].lower()

    def test_zero_rejected(self) -> None:
        result = validate_positive_int({"limit": 0}, "limit")
        assert isinstance(result, tuple)

    def test_negative_rejected(self) -> None:
        result = validate_positive_int({"limit": -1}, "limit")
        assert isinstance(result, tuple)

    def test_non_int_rejected(self) -> None:
        result = validate_positive_int({"limit": "10"}, "limit")
        assert isinstance(result, tuple)

    def test_bool_rejected(self) -> None:
        """``True`` is an int in Python — the validator must reject it anyway."""
        result = validate_positive_int({"limit": True}, "limit")
        assert isinstance(result, tuple)
15 changes: 15 additions & 0 deletions tests/test_mcp_coverage_gaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,21 @@ async def test_aggregate_group_by_not_string(self, store: DuckDBStore) -> None:
assert data["error"] is True
assert data["code"] == "INVALID_PARAMS"

async def test_aggregate_group_by_missing(self, store: DuckDBStore) -> None:
    """An absent ``group_by`` is reported as a missing required field."""
    response = await _handle_aggregate(store, {})
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    assert "Missing required fields" in data["message"]

async def test_aggregate_group_by_empty_string(self, store: DuckDBStore) -> None:
    """Issue #371: empty string is present, not missing."""
    response = await _handle_aggregate(store, {"group_by": ""})
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    assert "must be a non-empty string" in data["message"]


# ===========================================================================
# tools/crud.py gap coverage
Expand Down
21 changes: 21 additions & 0 deletions tests/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,27 @@ async def test_search_missing_query_returns_error(self, store: DuckDBStore) -> N
data = parse_mcp_response(response)
assert data["error"] is True
assert data["code"] == "INVALID_PARAMS"
assert "Missing required fields" in data["message"]

async def test_search_empty_query_reports_non_empty_required(self, store: DuckDBStore) -> None:
    """Issue #371: empty string is present, not missing — agents retrying
    with the same empty payload would loop forever on the old message."""
    response = await _handle_search(store, {"query": ""})
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    assert "must be a non-empty string" in data["message"]
    # The "missing" wording must not leak into the empty-string case.
    assert "Missing required fields" not in data["message"]

async def test_search_whitespace_query_reports_non_empty_required(
    self, store: DuckDBStore
) -> None:
    """Issue #371: whitespace-only queries are functionally empty."""
    response = await _handle_search(store, {"query": " "})
    data = parse_mcp_response(response)
    assert data["error"] is True
    assert data["code"] == "INVALID_PARAMS"
    assert "must be a non-empty string" in data["message"]

async def test_search_respects_limit(self, store: DuckDBStore) -> None:
for i in range(10):
Expand Down
Loading