diff --git a/src/distillery/mcp/tools/_common.py b/src/distillery/mcp/tools/_common.py index dd1ae605..f7d07e7e 100644 --- a/src/distillery/mcp/tools/_common.py +++ b/src/distillery/mcp/tools/_common.py @@ -100,10 +100,22 @@ def success_response(data: dict[str, Any]) -> list[types.TextContent]: def validate_required(arguments: dict[str, Any], *fields: str) -> str | None: - """Return an error message if any required field is missing from *arguments*. + """Return an error message if any required field is absent, empty, or blank. - A field is considered missing if it is absent, ``None``, or (for strings) - empty. Values such as ``0`` and ``False`` are **not** treated as missing. + Distinguishes two failure modes so that agents parsing the error message + can recover without retrying with the same payload: + + * **Missing** — field is absent from ``arguments`` or set to ``None``. + Reported as ``"Missing required fields: ..."``. + * **Empty** — field is present as a string that is empty or whitespace-only. + Reported as ``"Field '...' must be a non-empty string"`` (or the plural + form for multiple fields). + + Non-string falsy values (``0``, ``False``, ``[]``, ``{}``) are **not** + treated as missing or empty — those are valid inputs for fields that + accept them. + + When both categories fail in the same call, missing is reported first. Args: arguments: The tool argument dict. @@ -111,15 +123,22 @@ def validate_required(arguments: dict[str, Any], *fields: str) -> str | None: Returns: An error message string if validation fails, or ``None`` if all fields - are present. + are present and non-empty. """ - missing = [ - f - for f in fields - if arguments.get(f) is None or (isinstance(arguments.get(f), str) and not arguments.get(f)) - ] + missing: list[str] = [] + empty: list[str] = [] + for f in fields: + value = arguments.get(f) + if value is None: + missing.append(f) + elif isinstance(value, str) and not value.strip(): + empty.append(f) if missing: return f"Missing required fields: {', '.join(missing)}" + if empty: + if len(empty) == 1: + return f"Field {empty[0]!r} must be a non-empty string" + return f"Fields must be non-empty strings: {', '.join(repr(f) for f in empty)}" return None diff --git a/src/distillery/mcp/tools/search.py b/src/distillery/mcp/tools/search.py index 17ad6a2b..a4636064 100644 --- a/src/distillery/mcp/tools/search.py +++ b/src/distillery/mcp/tools/search.py @@ -346,12 +346,13 @@ async def _handle_aggregate( MCP content list with a JSON payload containing ``group_by``, ``groups``, ``total_entries``, and ``total_groups``. """ - group_by = arguments.get("group_by", "") err_group_by = validate_type(arguments, "group_by", str, "string") if err_group_by: return error_response("INVALID_PARAMS", err_group_by) - if not group_by: - return error_response("INVALID_PARAMS", "Missing required field: group_by") + err = validate_required(arguments, "group_by") + if err: + return error_response("INVALID_PARAMS", err) + group_by: str = arguments["group_by"] if group_by not in _AGGREGATE_GROUP_BY_MAP: return error_response( "INVALID_PARAMS", diff --git a/tests/test_mcp_classify.py b/tests/test_mcp_classify.py index 0c51049d..c055d17d 100644 --- a/tests/test_mcp_classify.py +++ b/tests/test_mcp_classify.py @@ -287,6 +287,26 @@ async def test_classify_validates_confidence_out_of_range( assert data["error"] is True assert data["code"] == "INVALID_PARAMS" + async def test_classify_reports_invalid_entry_type_before_id_lookup( + self, store: DuckDBStore, config: DistilleryConfig + ) -> None: + """Issue #372: when both entry_id and entry_type are bad, schema + validation must fire before the DB lookup so the agent sees the full + set of errors and does not burn a retry fixing one at a time.""" + response = await _handle_classify( + store, + config, + { + "entry_id": "00000000-0000-0000-0000-000000000000", + "entry_type": "invalid_type_xyz", + "confidence": 0.5, + }, + ) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert data["details"]["field"] == "entry_type" + # --------------------------------------------------------------------------- # distillery_review_queue tests @@ -572,6 +592,23 @@ async def test_resolve_validates_invalid_action(self, store: DuckDBStore) -> Non assert data["error"] is True assert data["code"] == "INVALID_PARAMS" + async def test_resolve_reports_invalid_action_before_id_lookup( + self, store: DuckDBStore + ) -> None: + """Issue #372: bad id + bad action should surface the action error + rather than short-circuiting on NOT_FOUND for the phantom id.""" + response = await _handle_resolve_review( + store, + { + "entry_id": "00000000-0000-0000-0000-000000000000", + "action": "nuke_from_orbit", + }, + ) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert "nuke_from_orbit" in data["message"] + async def test_resolve_approve_without_reviewer(self, store: DuckDBStore) -> None: """Reviewer field is optional.""" entry = make_entry(status=EntryStatus.PENDING_REVIEW) diff --git a/tests/test_mcp_common.py b/tests/test_mcp_common.py new file mode 100644 index 00000000..c1db63e4 --- /dev/null +++ b/tests/test_mcp_common.py @@ -0,0 +1,159 @@ +"""Unit tests for distillery.mcp.tools._common validation helpers. + +Covers ``validate_required``, ``validate_type``, ``validate_enum``, and +``validate_positive_int``. These helpers shape every INVALID_PARAMS error +returned by the MCP tools, so their contract is asserted here explicitly — +agents rely on the error message to recover without replaying identical +payloads (issue #371). +""" + +from __future__ import annotations + +import pytest + +from distillery.mcp.tools._common import ( + validate_enum, + validate_positive_int, + validate_required, + validate_type, +) + +pytestmark = pytest.mark.unit + + +class TestValidateRequired: + """Tests for ``validate_required`` — field presence and non-emptiness.""" + + def test_all_fields_present_returns_none(self) -> None: + assert validate_required({"a": "x", "b": "y"}, "a", "b") is None + + def test_single_absent_field(self) -> None: + msg = validate_required({"a": "x"}, "a", "b") + assert msg == "Missing required fields: b" + + def test_multiple_absent_fields_listed_in_order(self) -> None: + msg = validate_required({}, "a", "b", "c") + assert msg == "Missing required fields: a, b, c" + + def test_explicit_none_treated_as_missing(self) -> None: + msg = validate_required({"a": None}, "a") + assert msg == "Missing required fields: a" + + def test_empty_string_reports_non_empty_message(self) -> None: + """Issue #371: empty string is present, not missing.""" + msg = validate_required({"query": ""}, "query") + assert msg == "Field 'query' must be a non-empty string" + + def test_whitespace_only_string_treated_as_empty(self) -> None: + """Issue #371: whitespace-only strings are functionally empty.""" + msg = validate_required({"query": " \t\n"}, "query") + assert msg == "Field 'query' must be a non-empty string" + + def test_multiple_empty_fields(self) -> None: + msg = validate_required({"a": "", "b": ""}, "a", "b") + assert msg == "Fields must be non-empty strings: 'a', 'b'" + + def test_missing_takes_precedence_over_empty(self) -> None: + """When both categories fail, report missing first so clients fix the + more fundamental issue before retrying.""" + msg = validate_required({"a": ""}, "a", "b") + assert msg == "Missing required fields: b" + + def test_zero_is_not_treated_as_missing(self) -> None: + """Guards the #245 hint in issue #371: ``0`` is a valid int payload.""" + assert validate_required({"n": 0}, "n") is None + + def test_false_is_not_treated_as_missing(self) -> None: + assert validate_required({"flag": False}, "flag") is None + + def test_empty_list_is_not_treated_as_missing(self) -> None: + """Non-string falsy values are valid — only strings get the empty check.""" + assert validate_required({"tags": []}, "tags") is None + + def test_empty_dict_is_not_treated_as_missing(self) -> None: + assert validate_required({"meta": {}}, "meta") is None + + def test_non_string_non_none_value_passes(self) -> None: + assert validate_required({"n": 42, "flag": True}, "n", "flag") is None + + +class TestValidateType: + """Tests for ``validate_type``.""" + + def test_correct_type_returns_none(self) -> None: + assert validate_type({"x": [1, 2]}, "x", list, "list") is None + + def test_wrong_type_returns_message(self) -> None: + msg = validate_type({"x": "not-a-list"}, "x", list, "list of strings") + assert msg == "Field 'x' must be a list of strings" + + def test_absent_field_is_ok(self) -> None: + """validate_type only runs when the field is present — presence is the + job of validate_required.""" + assert validate_type({}, "x", list, "list") is None + + def test_none_value_is_ok(self) -> None: + assert validate_type({"x": None}, "x", list, "list") is None + + def test_tuple_of_types_accepted(self) -> None: + assert validate_type({"n": 1.5}, "n", (int, float), "number") is None + assert validate_type({"n": 1}, "n", (int, float), "number") is None + + +class TestValidateEnum: + """Tests for ``validate_enum``.""" + + def test_valid_value_returns_none(self) -> None: + assert validate_enum({"action": "approve"}, "action", {"approve", "archive"}) is None + + def test_invalid_value_returns_message(self) -> None: + msg = validate_enum({"action": "nuke"}, "action", {"approve", "archive"}) + assert msg is not None + assert "nuke" in msg + assert "approve" in msg and "archive" in msg + + def test_absent_field_is_ok(self) -> None: + assert validate_enum({}, "action", {"approve"}) is None + + def test_non_string_value_rejected(self) -> None: + """JSON arrays/objects arrive as list/dict — must not raise TypeError.""" + msg = validate_enum({"action": ["approve"]}, "action", {"approve"}) + assert msg is not None + assert "approve" in msg + + def test_label_overrides_field_name_in_message(self) -> None: + msg = validate_enum({"a": "bad"}, "a", {"good"}, label="choice") + assert msg is not None + assert "Invalid choice" in msg + + +class TestValidatePositiveInt: + """Tests for ``validate_positive_int``.""" + + def test_valid_int_returned(self) -> None: + assert validate_positive_int({"limit": 10}, "limit") == 10 + + def test_default_used_when_absent(self) -> None: + assert validate_positive_int({}, "limit", default=5) == 5 + + def test_missing_without_default_returns_error(self) -> None: + result = validate_positive_int({}, "limit") + assert isinstance(result, tuple) + assert "required" in result[1].lower() + + def test_zero_rejected(self) -> None: + result = validate_positive_int({"limit": 0}, "limit") + assert isinstance(result, tuple) + + def test_negative_rejected(self) -> None: + result = validate_positive_int({"limit": -1}, "limit") + assert isinstance(result, tuple) + + def test_non_int_rejected(self) -> None: + result = validate_positive_int({"limit": "10"}, "limit") + assert isinstance(result, tuple) + + def test_bool_rejected(self) -> None: + """``True`` is an int in Python — the validator must reject it anyway.""" + result = validate_positive_int({"limit": True}, "limit") + assert isinstance(result, tuple) diff --git a/tests/test_mcp_coverage_gaps.py b/tests/test_mcp_coverage_gaps.py index ea76d3ec..b7b15f8a 100644 --- a/tests/test_mcp_coverage_gaps.py +++ b/tests/test_mcp_coverage_gaps.py @@ -256,6 +256,21 @@ async def test_aggregate_group_by_not_string(self, store: DuckDBStore) -> None: assert data["error"] is True assert data["code"] == "INVALID_PARAMS" + async def test_aggregate_group_by_missing(self, store: DuckDBStore) -> None: + response = await _handle_aggregate(store, {}) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert "Missing required fields" in data["message"] + + async def test_aggregate_group_by_empty_string(self, store: DuckDBStore) -> None: + """Issue #371: empty string is present, not missing.""" + response = await _handle_aggregate(store, {"group_by": ""}) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert "must be a non-empty string" in data["message"] + # =========================================================================== # tools/crud.py gap coverage diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 4da3e700..b365068d 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -503,6 +503,27 @@ async def test_search_missing_query_returns_error(self, store: DuckDBStore) -> N data = parse_mcp_response(response) assert data["error"] is True assert data["code"] == "INVALID_PARAMS" + assert "Missing required fields" in data["message"] + + async def test_search_empty_query_reports_non_empty_required(self, store: DuckDBStore) -> None: + """Issue #371: empty string is present, not missing — agents retrying + with the same empty payload would loop forever on the old message.""" + response = await _handle_search(store, {"query": ""}) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert "must be a non-empty string" in data["message"] + assert "Missing required fields" not in data["message"] + + async def test_search_whitespace_query_reports_non_empty_required( + self, store: DuckDBStore + ) -> None: + """Issue #371: whitespace-only queries are functionally empty.""" + response = await _handle_search(store, {"query": " "}) + data = parse_mcp_response(response) + assert data["error"] is True + assert data["code"] == "INVALID_PARAMS" + assert "must be a non-empty string" in data["message"] async def test_search_respects_limit(self, store: DuckDBStore) -> None: for i in range(10):