From ec04412eaacaec8cb69f4fb10149db6649423be3 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Tue, 31 Mar 2026 23:59:28 +0800 Subject: [PATCH 01/21] fix(agent): remove invoke timeout caps --- src/Undefined/api/app.py | 48 ++++++++++++++++++++-- src/Undefined/skills/agents/__init__.py | 2 +- src/Undefined/skills/registry.py | 28 ++++++++++--- src/Undefined/webui/routes/_runtime.py | 17 ++++++-- tests/test_agent_registry.py | 51 ++++++++++++++++++++++++ tests/test_queue_timeout_budgets.py | 15 +++++++ tests/test_runtime_api_tool_invoke.py | 53 +++++++++++++++++++++++++ 7 files changed, 201 insertions(+), 13 deletions(-) create mode 100644 tests/test_agent_registry.py diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index a0d00a5f..ea33b8b2 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -40,6 +40,10 @@ _NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 +class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): + """由 Runtime API 工具调用超时包装器抛出的超时异常。""" + + @dataclass class _NagaRequestResult: payload_hash: str @@ -1509,6 +1513,43 @@ def _is_toolset(name: str) -> bool: return filtered + def _get_agent_tool_names(self) -> set[str]: + ai = self._ctx.ai + if ai is None: + return set() + + agent_reg = getattr(ai, "agent_registry", None) + if agent_reg is None: + return set() + + agent_names: set[str] = set() + for schema in agent_reg.get_agents_schema(): + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + return agent_names + + def _resolve_tool_invoke_timeout( + self, tool_name: str, timeout: int + ) -> float | None: + if tool_name in self._get_agent_tool_names(): + return None + return float(timeout) + + async def _await_tool_invoke_result( + self, + awaitable: Awaitable[Any], + *, + timeout: float | None, + ) -> Any: + if timeout is None or timeout <= 0: + return await awaitable + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except 
asyncio.TimeoutError as exc: + raise _ToolInvokeExecutionTimeoutError from exc + async def _tools_list_handler(self, request: web.Request) -> Response: _ = request cfg = self._ctx.config_getter() @@ -1659,6 +1700,7 @@ async def _execute_tool_invoke( ) start = time.perf_counter() + effective_timeout = self._resolve_tool_invoke_timeout(tool_name, timeout) try: async with RequestContext( request_type=request_type, @@ -1699,9 +1741,9 @@ async def _execute_tool_invoke( if tool_manager is None: raise RuntimeError("ToolManager not available") - raw_result = await asyncio.wait_for( + raw_result = await self._await_tool_invoke_result( tool_manager.execute_tool(tool_name, args, tool_context), - timeout=timeout, + timeout=effective_timeout, ) elapsed_ms = round((time.perf_counter() - start) * 1000, 1) @@ -1722,7 +1764,7 @@ async def _execute_tool_invoke( "duration_ms": elapsed_ms, } - except asyncio.TimeoutError: + except _ToolInvokeExecutionTimeoutError: elapsed_ms = round((time.perf_counter() - start) * 1000, 1) logger.warning( "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", diff --git a/src/Undefined/skills/agents/__init__.py b/src/Undefined/skills/agents/__init__.py index 0d61511d..88e291db 100644 --- a/src/Undefined/skills/agents/__init__.py +++ b/src/Undefined/skills/agents/__init__.py @@ -26,7 +26,7 @@ def __init__(self, agents_dir: str | Path | None = None) -> None: else: agents_path = Path(agents_dir) - super().__init__(agents_path, kind="agent") + super().__init__(agents_path, kind="agent", timeout_seconds=0.0) self.set_watch_filenames( {"config.json", "handler.py", "intro.md", "intro.generated.md"} ) diff --git a/src/Undefined/skills/registry.py b/src/Undefined/skills/registry.py index 3f1c2adf..a197e4ac 100644 --- a/src/Undefined/skills/registry.py +++ b/src/Undefined/skills/registry.py @@ -14,6 +14,10 @@ logger = logging.getLogger(__name__) +class RegistryExecutionTimeoutError(asyncio.TimeoutError): + """由注册表超时包装器抛出的超时异常。""" + + @dataclass class 
SkillStats: """技能执行统计数据类 @@ -108,6 +112,9 @@ def set_watch_paths(self, paths: List[Path]) -> None: def set_watch_filenames(self, filenames: set[str]) -> None: self._watch_filenames = filenames + def _has_timeout(self) -> bool: + return self.timeout_seconds > 0 + def _log_event(self, event: str, name: str = "", **fields: Any) -> None: parts = [f"event={event}", f"kind={self.kind}"] if name: @@ -366,7 +373,7 @@ async def execute( result_payload = result return_value = str(result) - except asyncio.TimeoutError: + except RegistryExecutionTimeoutError: duration = time.monotonic() - start_time self._stats[name].record_failure(duration, "timeout") self._log_event( @@ -413,13 +420,24 @@ async def _execute_with_timeout( args: Dict[str, Any], context: Dict[str, Any], ) -> Any: + if not self._has_timeout(): + if asyncio.iscoroutinefunction(handler): + return await handler(args, context) + return await asyncio.to_thread(handler, args, context) + if asyncio.iscoroutinefunction(handler): + try: + return await asyncio.wait_for( + handler(args, context), timeout=self.timeout_seconds + ) + except asyncio.TimeoutError as exc: + raise RegistryExecutionTimeoutError from exc + try: return await asyncio.wait_for( - handler(args, context), timeout=self.timeout_seconds + asyncio.to_thread(handler, args, context), timeout=self.timeout_seconds ) - return await asyncio.wait_for( - asyncio.to_thread(handler, args, context), timeout=self.timeout_seconds - ) + except asyncio.TimeoutError as exc: + raise RegistryExecutionTimeoutError from exc def _compute_snapshot(self) -> Dict[str, tuple[int, int]]: snapshot: Dict[str, tuple[int, int]] = {} diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index f816a00c..2d2043a4 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -36,6 +36,16 @@ def _chat_proxy_timeout_seconds() -> float: return compute_queued_llm_timeout_seconds(cfg, cfg.chat_model) +def 
_tool_invoke_proxy_timeout_seconds(tool_name: str) -> float | None: + normalized_name = str(tool_name or "").strip() + if normalized_name.endswith("_agent"): + return None + + cfg = get_config(strict=False) + # 普通 tool 保持 Runtime API 超时 + 60s 网络缓冲。 + return float(cfg.api.tool_invoke_timeout) + 60.0 + + def _unauthorized() -> Response: return web.json_response({"error": "Unauthorized"}, status=401) @@ -82,7 +92,7 @@ async def _proxy_runtime( path: str, params: Mapping[str, str] | None = None, payload: dict[str, Any] | None = None, - timeout_seconds: float = 20.0, + timeout_seconds: float | None = 20.0, ) -> Response: cfg = get_config(strict=False) if not cfg.api.enabled: @@ -416,9 +426,8 @@ async def runtime_tools_invoke_handler(request: web.Request) -> Response: except Exception: return web.json_response({"error": "Invalid JSON"}, status=400) - cfg = get_config(strict=False) - # 代理超时 = 工具调用超时 + 60s 缓冲(覆盖网络开销) - proxy_timeout = float(cfg.api.tool_invoke_timeout) + 60.0 + tool_name = str(body.get("tool_name", "") or "").strip() + proxy_timeout = _tool_invoke_proxy_timeout_seconds(tool_name) return await _proxy_runtime( method="POST", path="/api/v1/tools/invoke", diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py new file mode 100644 index 00000000..77989a2d --- /dev/null +++ b/tests/test_agent_registry.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import asyncio +import json +from pathlib import Path + +import pytest + +from Undefined.skills.agents import AgentRegistry + + +@pytest.mark.asyncio +async def test_agent_registry_executes_without_registry_timeout( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + agent_dir = tmp_path / "demo_agent" + agent_dir.mkdir() + (agent_dir / "config.json").write_text( + json.dumps( + { + "type": "function", + "function": { + "name": "demo_agent", + "description": "demo", + "parameters": {"type": "object", "properties": {}}, + }, + } + ), + encoding="utf-8", + ) + (agent_dir / 
"handler.py").write_text( + "async def execute(args, context):\n return 'ok'\n", + encoding="utf-8", + ) + + registry = AgentRegistry(tmp_path) + original_wait_for = asyncio.wait_for + seen: dict[str, float] = {} + + async def _wait_for(awaitable, timeout): # type: ignore[no-untyped-def] + seen["timeout"] = timeout + return await original_wait_for(awaitable, timeout) + + monkeypatch.setattr("Undefined.skills.registry.asyncio.wait_for", _wait_for) + + result = await registry.execute_agent("demo_agent", {}, {}) + + assert result == "ok" + assert registry.timeout_seconds == 0.0 + assert "timeout" not in seen diff --git a/tests/test_queue_timeout_budgets.py b/tests/test_queue_timeout_budgets.py index 53d217d3..32b16091 100644 --- a/tests/test_queue_timeout_budgets.py +++ b/tests/test_queue_timeout_budgets.py @@ -132,3 +132,18 @@ def test_chat_proxy_timeout_uses_queue_budget(monkeypatch: pytest.MonkeyPatch) - cfg.chat_model, ) ) + + +def test_tool_invoke_proxy_timeout_skips_agents( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cfg = SimpleNamespace(api=SimpleNamespace(tool_invoke_timeout=120)) + monkeypatch.setattr( + runtime_routes, "get_config", lambda strict=False: cast(Any, cfg) + ) + + assert runtime_routes._tool_invoke_proxy_timeout_seconds("web_agent") is None + assert ( + runtime_routes._tool_invoke_proxy_timeout_seconds("messages.send_message") + == 180.0 + ) diff --git a/tests/test_runtime_api_tool_invoke.py b/tests/test_runtime_api_tool_invoke.py index 7af831eb..9db13303 100644 --- a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from types import SimpleNamespace from typing import Any, cast @@ -352,6 +353,58 @@ async def test_invoke_sync_success() -> None: assert "duration_ms" in payload +@pytest.mark.asyncio +async def test_invoke_tool_uses_runtime_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + server = 
_make_server(_make_api_cfg(tool_invoke_timeout=7)) + original_wait_for = asyncio.wait_for + seen: dict[str, float] = {} + + async def _wait_for(awaitable: Any, timeout: float) -> Any: + seen["timeout"] = timeout + return await original_wait_for(awaitable, timeout) + + monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + + payload = await server._execute_tool_invoke( + request_id="req-tool", + tool_name="get_current_time", + args={}, + body_context=None, + timeout=7, + ) + + assert payload["ok"] is True + assert seen["timeout"] == 7.0 + + +@pytest.mark.asyncio +async def test_invoke_agent_bypasses_runtime_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + server = _make_server(_make_api_cfg(tool_invoke_timeout=7)) + original_wait_for = asyncio.wait_for + seen: dict[str, float] = {} + + async def _wait_for(awaitable: Any, timeout: float) -> Any: + seen["timeout"] = timeout + return await original_wait_for(awaitable, timeout) + + monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + + payload = await server._execute_tool_invoke( + request_id="req-agent", + tool_name="web_agent", + args={}, + body_context=None, + timeout=7, + ) + + assert payload["ok"] is True + assert "timeout" not in seen + + @pytest.mark.asyncio async def test_invoke_with_context() -> None: server = _make_server() From 1c3061febd26c077ddc9aaf966c06f8481ced437 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 00:06:48 +0800 Subject: [PATCH 02/21] fix(image-gen): support base64 responses and preserve size --- .../tools/ai_draw_one/config.json | 11 +- .../tools/ai_draw_one/handler.py | 287 +++++++++++++++--- tests/test_ai_draw_one_handler.py | 141 +++++++++ 3 files changed, 391 insertions(+), 48 deletions(-) create mode 100644 tests/test_ai_draw_one_handler.py diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json index 
b19fd10b..ebce1dda 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json @@ -16,7 +16,7 @@ }, "size": { "type": "string", - "description": "比例/尺寸 (例如 1:1, 2:3, 1024x1024;可选,默认由配置决定)" + "description": "比例/尺寸。星之阁模式可用 1:1 等比例;models 模式仅支持 1280x720、720x1280、1792x1024、1024x1792、1024x1024" }, "quality": { "type": "string", @@ -26,6 +26,15 @@ "type": "string", "description": "图片风格 (例如 vivid, natural;可选,仅 OpenAI 模式生效)" }, + "response_format": { + "type": "string", + "description": "图片响应格式(仅 models 模式生效)", + "enum": ["url", "b64_json", "base64"] + }, + "n": { + "type": "integer", + "description": "生成图片数量(仅 models 模式生效,1 到 10)" + }, "target_id": { "type": "integer", "description": "发送目标的 ID" diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 59e19bad..70b9aefb 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -7,9 +7,12 @@ from __future__ import annotations +import base64 +import binascii import logging import time import uuid +from dataclasses import dataclass from typing import Any from Undefined.skills.http_client import request_with_retry @@ -17,6 +20,21 @@ logger = logging.getLogger(__name__) +_ALLOWED_MODELS_IMAGE_SIZES = ( + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", +) +_ALLOWED_IMAGE_RESPONSE_FORMATS = ("url", "b64_json", "base64") + + +@dataclass +class _GeneratedImagePayload: + image_url: str | None = None + image_bytes: bytes | None = None + def _record_image_gen_usage( model_name: str, prompt: str, duration_seconds: float, success: bool @@ -57,6 +75,149 @@ def _parse_image_url(data: dict[str, Any]) -> str | None: return None +def _parse_image_bytes(data: dict[str, Any]) 
-> bytes | None: + """从 API 响应中提取 Base64 图片内容。""" + try: + image_item = data["data"][0] + except (KeyError, IndexError, TypeError): + return None + + if not isinstance(image_item, dict): + return None + + for key in ("b64_json", "base64"): + raw_value = image_item.get(key) + if raw_value is None: + continue + text = str(raw_value).strip() + if not text: + continue + try: + return base64.b64decode(text) + except (binascii.Error, ValueError): + logger.error("图片 Base64 解码失败: key=%s", key) + return None + return None + + +def _parse_generated_image(data: dict[str, Any]) -> _GeneratedImagePayload | None: + image_url = _parse_image_url(data) + if image_url: + return _GeneratedImagePayload(image_url=image_url) + + image_bytes = _parse_image_bytes(data) + if image_bytes is not None: + return _GeneratedImagePayload(image_bytes=image_bytes) + + return None + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return value != 0 + return str(value or "").strip().lower() in {"1", "true", "yes", "on"} + + +def _build_openai_models_request_body( + *, + prompt: str, + model_name: str, + size: str, + quality: str, + style: str, + response_format: str, + n: int | None, + extra_params: dict[str, Any], +) -> dict[str, Any]: + from Undefined.utils.request_params import merge_request_params + + body = merge_request_params(extra_params) + body["prompt"] = prompt + if n is not None: + body["n"] = n + else: + body.setdefault("n", 1) + if model_name: + body["model"] = model_name + if size: + body["size"] = size + if quality: + body["quality"] = quality + if style: + body["style"] = style + if response_format: + body["response_format"] = response_format + return body + + +def _validate_openai_models_request_body(body: dict[str, Any]) -> str | None: + size = str(body.get("size", "") or "").strip() + if size and size not in _ALLOWED_MODELS_IMAGE_SIZES: + supported = ", ".join(_ALLOWED_MODELS_IMAGE_SIZES) + return f"size 
无效:{size}。models provider 仅支持: {supported}" + + response_format = str(body.get("response_format", "") or "").strip() + if response_format and response_format not in _ALLOWED_IMAGE_RESPONSE_FORMATS: + supported = ", ".join(_ALLOWED_IMAGE_RESPONSE_FORMATS) + return f"response_format 无效:{response_format}。仅支持: {supported}" + + raw_n = body.get("n", 1) + try: + n = int(raw_n) + except (TypeError, ValueError): + return f"n 无效:{raw_n}。必须是 1 到 10 的整数" + + if not 1 <= n <= 10: + return f"n 无效:{n}。必须是 1 到 10 的整数" + + if _coerce_bool(body.get("stream")): + if n not in {1, 2}: + return "stream=true 时 n 只能是 1 或 2" + return "暂不支持 stream=true 的绘图响应" + + return None + + +def _detect_image_suffix(image_bytes: bytes) -> str: + if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"): + return ".png" + if image_bytes.startswith(b"\xff\xd8\xff"): + return ".jpg" + if image_bytes.startswith((b"GIF87a", b"GIF89a")): + return ".gif" + if image_bytes.startswith(b"BM"): + return ".bmp" + if image_bytes.startswith(b"RIFF") and image_bytes[8:12] == b"WEBP": + return ".webp" + return ".png" + + +def _write_image_cache_file(image_bytes: bytes) -> str: + from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir + + suffix = _detect_image_suffix(image_bytes) + filename = f"ai_draw_{uuid.uuid4().hex[:8]}{suffix}" + filepath = ensure_dir(IMAGE_CACHE_DIR) / filename + with open(filepath, "wb") as f: + f.write(image_bytes) + return str(filepath) + + +async def _send_cached_image( + filepath: str, + target_id: int | str, + message_type: str, + context: dict[str, Any], +) -> str: + send_image_callback = context.get("send_image_callback") + if send_image_callback: + await send_image_callback(target_id, message_type, filepath) + return f"AI 绘图已发送给 {message_type} {target_id}" + return "发送图片回调未设置" + + async def _call_xingzhige(prompt: str, size: str, context: dict[str, Any]) -> str: """调用星之阁免费 API""" url = get_xingzhige_url("/API/DrawOne/") @@ -94,34 +255,36 @@ async def _call_openai_models( size: str, 
quality: str, style: str, + response_format: str, + n: int | None, timeout_val: float, context: dict[str, Any], -) -> str: +) -> _GeneratedImagePayload | str: """调用 OpenAI 兼容的图片生成接口""" - from Undefined.utils.request_params import merge_request_params - - # 构建请求 body(仅包含非空字段,其余由上游使用默认值) - body: dict[str, Any] = { - "prompt": prompt, - "n": 1, - } - if model_name: - body["model"] = model_name - if size: - body["size"] = size - if quality: - body["quality"] = quality - if style: - body["style"] = style # 追加 request_params + extra_params: dict[str, Any] = {} try: from Undefined.config import get_config extra_params = get_config(strict=False).models_image_gen.request_params - body = merge_request_params(body, extra_params) except Exception: - pass + extra_params = {} + + body = _build_openai_models_request_body( + prompt=prompt, + model_name=model_name, + size=size, + quality=quality, + style=style, + response_format=response_format, + n=n, + extra_params=extra_params, + ) + + validation_error = _validate_openai_models_request_body(body) + if validation_error: + return validation_error # 确保 base_url 末尾带 /v1 base_url = api_url.rstrip("/") @@ -149,14 +312,17 @@ async def _call_openai_models( except Exception: return f"API 返回错误 (非JSON): {response.text[:100]}" - image_url = _parse_image_url(data) - if image_url is None: + generated_image = _parse_generated_image(data) + if generated_image is None: logger.error(f"图片生成 API 返回 (未找到图片链接): {data}") - return f"API 返回原文 (错误:未找到图片链接): {data}" + return f"API 返回原文 (错误:未找到图片内容): {data}" logger.info(f"图片生成 API 返回: {data}") - logger.info(f"提取图片链接: {image_url}") - return image_url + if generated_image.image_url: + logger.info(f"提取图片链接: {generated_image.image_url}") + elif generated_image.image_bytes is not None: + logger.info("提取图片字节: bytes=%s", len(generated_image.image_bytes)) + return generated_image async def _download_and_send( @@ -173,20 +339,18 @@ async def _download_and_send( timeout=max(timeout_val, 15.0), context=context, ) + 
filepath = _write_image_cache_file(img_response.content) + return await _send_cached_image(filepath, target_id, message_type, context) - filename = f"ai_draw_{uuid.uuid4().hex[:8]}.jpg" - from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir - - filepath = ensure_dir(IMAGE_CACHE_DIR) / filename - with open(filepath, "wb") as f: - f.write(img_response.content) - - send_image_callback = context.get("send_image_callback") - if send_image_callback: - await send_image_callback(target_id, message_type, str(filepath)) - return f"AI 绘图已发送给 {message_type} {target_id}" - return "发送图片回调未设置" +async def _save_and_send( + image_bytes: bytes, + target_id: int | str, + message_type: str, + context: dict[str, Any], +) -> str: + filepath = _write_image_cache_file(image_bytes) + return await _send_cached_image(filepath, target_id, message_type, context) async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: @@ -195,6 +359,11 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: prompt_arg: str | None = args.get("prompt") size_arg: str | None = args.get("size") + model_arg: str | None = args.get("model") + quality_arg: str | None = args.get("quality") + style_arg: str | None = args.get("style") + response_format_arg: str | None = args.get("response_format") + n_arg = args.get("n") target_id: int | str | None = args.get("target_id") message_type_arg: str | None = args.get("message_type") @@ -206,22 +375,30 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: start_time = time.time() success = False used_model = provider + generated_result: str | _GeneratedImagePayload try: if provider == "xingzhige": prompt = prompt_arg or "" size = size_arg or cfg.xingzhige_size - image_url = await _call_xingzhige(prompt, size, context) + generated_result = await _call_xingzhige(prompt, size, context) elif provider == "models": prompt = prompt_arg or "" # 降级到 models.image_gen 配置,未填则降级到 chat_model api_url = gen_cfg.api_url or 
chat_cfg.api_url api_key = gen_cfg.api_key or chat_cfg.api_key - model_name = gen_cfg.model_name - size = size_arg or cfg.openai_size - quality = cfg.openai_quality - style = cfg.openai_style + model_name = str(model_arg or gen_cfg.model_name or "").strip() + size = str(size_arg or cfg.openai_size or "").strip() + quality = str(quality_arg or cfg.openai_quality or "").strip() + style = str(style_arg or cfg.openai_style or "").strip() + response_format = str(response_format_arg or "").strip() timeout_val = cfg.openai_timeout + n_value: int | None = None + if n_arg is not None and str(n_arg).strip() != "": + try: + n_value = int(n_arg) + except (TypeError, ValueError): + return f"n 无效:{n_arg}。必须是 1 到 10 的整数" if not api_url: return "图片生成失败:未配置 models.image_gen.api_url" @@ -229,7 +406,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: return "图片生成失败:未配置 models.image_gen.api_key" used_model = model_name or "openai-image-gen" - image_url = await _call_openai_models( + generated_result = await _call_openai_models( prompt=prompt, api_url=api_url, api_key=api_key, @@ -237,6 +414,8 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: size=size, quality=quality, style=style, + response_format=response_format, + n=n_value, timeout_val=timeout_val, context=context, ) @@ -246,17 +425,31 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: "请在 config.toml 中设置 image_gen.provider 为 xingzhige 或 models" ) - # 判断是否返回了错误消息(而非图片 URL) - if not image_url.startswith("http"): - return image_url + if isinstance(generated_result, _GeneratedImagePayload): + generated_image = generated_result + else: + if not generated_result.startswith("http"): + return generated_result + generated_image = _GeneratedImagePayload(image_url=generated_result) if target_id is None or message_type_arg is None: return "图片生成成功,但缺少发送目标参数" send_timeout = get_request_timeout(60.0) - result = await _download_and_send( - image_url, target_id, 
message_type_arg, send_timeout, context - ) + if generated_image.image_url: + result = await _download_and_send( + generated_image.image_url, + target_id, + message_type_arg, + send_timeout, + context, + ) + elif generated_image.image_bytes is not None: + result = await _save_and_send( + generated_image.image_bytes, target_id, message_type_arg, context + ) + else: + return "图片生成失败:未找到可发送的图片内容" success = True return result diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py new file mode 100644 index 00000000..16d30c40 --- /dev/null +++ b/tests/test_ai_draw_one_handler.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import base64 +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest + +from Undefined.skills.agents.entertainment_agent.tools.ai_draw_one import ( + handler as ai_draw_handler, +) + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +def _make_runtime_config(*, request_params: dict[str, Any] | None = None) -> Any: + return SimpleNamespace( + image_gen=SimpleNamespace( + provider="models", + openai_size="", + openai_quality="", + openai_style="", + openai_timeout=120.0, + ), + models_image_gen=SimpleNamespace( + api_url="https://image.example.com", + api_key="sk-image", + model_name="grok-imagine-1.0", + request_params=request_params or {}, + ), + chat_model=SimpleNamespace( + api_url="https://chat.example.com", + api_key="sk-chat", + ), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_key", ["b64_json", "base64"]) +async def test_execute_models_supports_base64_response_and_preserves_explicit_size( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + response_key: str, +) -> None: + runtime_config = _make_runtime_config( + request_params={ 
+ "size": "1792x1024", + "response_format": "url", + } + ) + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: runtime_config, + ) + monkeypatch.setattr("Undefined.utils.paths.IMAGE_CACHE_DIR", tmp_path) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + seen_request: dict[str, Any] = {} + request_count = 0 + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{response_key: payload_base64}]} + + async def _fake_request_with_retry( + method: str, + url: str, + **kwargs: Any, + ) -> _FakeResponse: + nonlocal request_count + request_count += 1 + seen_request["method"] = method + seen_request["url"] = url + seen_request["json_data"] = kwargs.get("json_data") + return _FakeResponse() + + sent: dict[str, Any] = {} + + async def _send_image( + target_id: int | str, + message_type: str, + file_path: str, + ) -> None: + sent["target_id"] = target_id + sent["message_type"] = message_type + sent["file_path"] = file_path + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1024x1024", + "response_format": response_key, + "target_id": 10001, + "message_type": "group", + }, + {"send_image_callback": _send_image}, + ) + + assert result == "AI 绘图已发送给 group 10001" + assert request_count == 1 + assert seen_request["method"] == "POST" + assert seen_request["json_data"]["size"] == "1024x1024" + assert seen_request["json_data"]["response_format"] == response_key + assert Path(sent["file_path"]).read_bytes() == _PNG_BYTES + + +@pytest.mark.asyncio +async def test_execute_models_rejects_invalid_size( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(), + ) + + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1:1", + "target_id": 10001, + 
"message_type": "group", + }, + {"send_image_callback": lambda *_args, **_kwargs: None}, + ) + + assert "size 无效" in result + assert "1024x1024" in result From 49ea6670f401b7d7c8a8ddc2f26d0613b37782ad Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 00:15:32 +0800 Subject: [PATCH 03/21] fix(skills): register dynamic modules before exec --- src/Undefined/skills/registry.py | 9 +++++++- tests/test_agent_registry.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/Undefined/skills/registry.py b/src/Undefined/skills/registry.py index a197e4ac..4352a34d 100644 --- a/src/Undefined/skills/registry.py +++ b/src/Undefined/skills/registry.py @@ -254,7 +254,14 @@ def _load_handler_for_item( raise RuntimeError(f"加载处理器 spec 失败: {item.handler_path}") module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + sys.modules[item.module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + current = sys.modules.get(item.module_name) + if current is module: + del sys.modules[item.module_name] + raise if not hasattr(module, "execute"): raise RuntimeError(f"{item.handler_path} 的处理器缺少 'execute' 函数") diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py index 77989a2d..3c8e2f08 100644 --- a/tests/test_agent_registry.py +++ b/tests/test_agent_registry.py @@ -49,3 +49,39 @@ async def _wait_for(awaitable, timeout): # type: ignore[no-untyped-def] assert result == "ok" assert registry.timeout_seconds == 0.0 assert "timeout" not in seen + + +@pytest.mark.asyncio +async def test_agent_registry_loads_handler_with_dataclass(tmp_path: Path) -> None: + agent_dir = tmp_path / "demo_agent" + agent_dir.mkdir() + (agent_dir / "config.json").write_text( + json.dumps( + { + "type": "function", + "function": { + "name": "demo_agent", + "description": "demo", + "parameters": {"type": "object", "properties": {}}, + }, + } + ), + encoding="utf-8", + ) + 
(agent_dir / "handler.py").write_text( + "from dataclasses import dataclass\n" + "\n" + "@dataclass\n" + "class Payload:\n" + " value: str = 'ok'\n" + "\n" + "async def execute(args, context):\n" + " return Payload().value\n", + encoding="utf-8", + ) + + registry = AgentRegistry(tmp_path) + + result = await registry.execute_agent("demo_agent", {}, {}) + + assert result == "ok" From b8477b9a636ae895c9f12195e1d5b605b729d15e Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 10:45:55 +0800 Subject: [PATCH 04/21] fix(image-gen): expose upstream errors and lock model --- .../tools/ai_draw_one/config.json | 4 - .../tools/ai_draw_one/handler.py | 54 +++++++++-- tests/test_ai_draw_one_handler.py | 96 +++++++++++++++++++ 3 files changed, 140 insertions(+), 14 deletions(-) diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json index ebce1dda..c309fa93 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json @@ -10,10 +10,6 @@ "type": "string", "description": "绘图提示词" }, - "model": { - "type": "string", - "description": "绘图模型 (例如: anylora, dall-e-3, minimax-image-01;可选,默认由配置决定)" - }, "size": { "type": "string", "description": "比例/尺寸。星之阁模式可用 1:1 等比例;models 模式仅支持 1280x720、720x1280、1792x1024、1024x1792、1024x1024" diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 70b9aefb..bcefe674 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Any +import httpx + from Undefined.skills.http_client import request_with_retry 
from Undefined.skills.http_config import get_request_timeout, get_xingzhige_url @@ -120,6 +122,31 @@ def _coerce_bool(value: Any) -> bool: return str(value or "").strip().lower() in {"1", "true", "yes", "on"} +def _format_upstream_error_message(response: httpx.Response) -> str: + default_message = response.text[:200] or f"HTTP {response.status_code}" + try: + data = response.json() + except Exception: + return default_message + + if not isinstance(data, dict): + return default_message + + error = data.get("error") + if isinstance(error, dict): + code = str(error.get("code", "") or "").strip() + message = str(error.get("message", "") or "").strip() + if code and message: + return f"{code}: {message}" + if message: + return message + if code: + return code + + message = str(data.get("message", "") or "").strip() + return message or default_message + + def _build_openai_models_request_body( *, prompt: str, @@ -298,14 +325,22 @@ async def _call_openai_models( if api_key: headers["Authorization"] = f"Bearer {api_key}" - response = await request_with_retry( - "POST", - url, - json_data=body, - headers=headers, - timeout=timeout_val, - context=context, - ) + try: + response = await request_with_retry( + "POST", + url, + json_data=body, + headers=headers, + timeout=timeout_val, + context=context, + ) + except httpx.HTTPStatusError as exc: + message = _format_upstream_error_message(exc.response) + return f"图片生成请求失败: HTTP {exc.response.status_code} {message}" + except httpx.TimeoutException: + return f"图片生成请求超时({timeout_val:.0f}s)" + except httpx.RequestError as exc: + return f"图片生成请求失败: {exc}" try: data = response.json() @@ -359,7 +394,6 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: prompt_arg: str | None = args.get("prompt") size_arg: str | None = args.get("size") - model_arg: str | None = args.get("model") quality_arg: str | None = args.get("quality") style_arg: str | None = args.get("style") response_format_arg: str | None = 
args.get("response_format") @@ -387,7 +421,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: # 降级到 models.image_gen 配置,未填则降级到 chat_model api_url = gen_cfg.api_url or chat_cfg.api_url api_key = gen_cfg.api_key or chat_cfg.api_key - model_name = str(model_arg or gen_cfg.model_name or "").strip() + model_name = str(gen_cfg.model_name or "").strip() size = str(size_arg or cfg.openai_size or "").strip() quality = str(quality_arg or cfg.openai_quality or "").strip() style = str(style_arg or cfg.openai_style or "").strip() diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 16d30c40..0b66f212 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from typing import Any +import httpx import pytest from Undefined.skills.agents.entertainment_agent.tools.ai_draw_one import ( @@ -139,3 +140,98 @@ async def test_execute_models_rejects_invalid_size( assert "size 无效" in result assert "1024x1024" in result + + +@pytest.mark.asyncio +async def test_execute_models_reports_upstream_http_error_detail( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(), + ) + + request = httpx.Request("POST", "https://image.example.com/v1/images/generations") + response = httpx.Response( + 503, + request=request, + json={ + "error": { + "code": "upstream_error", + "message": "Image generation blocked or no valid final image", + } + }, + ) + + async def _fake_request_with_retry(*_args: Any, **_kwargs: Any) -> Any: + raise httpx.HTTPStatusError("boom", request=request, response=response) + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1024x1024", + "target_id": 10001, + "message_type": "group", + }, + {"send_image_callback": lambda 
*_args, **_kwargs: None}, + ) + + assert "HTTP 503" in result + assert "upstream_error" in result + assert "Image generation blocked or no valid final image" in result + + +@pytest.mark.asyncio +async def test_execute_models_uses_configured_model_only( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(), + ) + monkeypatch.setattr("Undefined.utils.paths.IMAGE_CACHE_DIR", tmp_path) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + seen_request: dict[str, Any] = {} + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry( + method: str, + url: str, + **kwargs: Any, + ) -> _FakeResponse: + seen_request["method"] = method + seen_request["url"] = url + seen_request["json_data"] = kwargs.get("json_data") + return _FakeResponse() + + async def _send_image( + target_id: int | str, + message_type: str, + file_path: str, + ) -> None: + _ = target_id, message_type, file_path + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "model": "dall-e-3", + "size": "1024x1024", + "target_id": 10001, + "message_type": "group", + }, + {"send_image_callback": _send_image}, + ) + + assert result == "AI 绘图已发送给 group 10001" + assert seen_request["json_data"]["model"] == "grok-imagine-1.0" From 03ce712391fb079e8c6650de5806728378ebeb98 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 14:42:49 +0800 Subject: [PATCH 05/21] fix(image-gen): default responses to base64 --- .../tools/ai_draw_one/handler.py | 2 + tests/test_ai_draw_one_handler.py | 53 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py 
b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index bcefe674..6215a46a 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -176,6 +176,8 @@ def _build_openai_models_request_body( body["style"] = style if response_format: body["response_format"] = response_format + else: + body.setdefault("response_format", "base64") return body diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 0b66f212..d8c09262 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -142,6 +142,59 @@ async def test_execute_models_rejects_invalid_size( assert "1024x1024" in result +@pytest.mark.asyncio +async def test_execute_models_defaults_response_format_to_base64( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(request_params={}), + ) + monkeypatch.setattr("Undefined.utils.paths.IMAGE_CACHE_DIR", tmp_path) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + seen_request: dict[str, Any] = {} + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry( + method: str, + url: str, + **kwargs: Any, + ) -> _FakeResponse: + seen_request["method"] = method + seen_request["url"] = url + seen_request["json_data"] = kwargs.get("json_data") + return _FakeResponse() + + async def _send_image( + target_id: int | str, + message_type: str, + file_path: str, + ) -> None: + _ = target_id, message_type, file_path + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1024x1024", + "target_id": 10001, + "message_type": "group", + }, + 
{"send_image_callback": _send_image}, + ) + + assert result == "AI 绘图已发送给 group 10001" + assert seen_request["json_data"]["response_format"] == "base64" + + @pytest.mark.asyncio async def test_execute_models_reports_upstream_http_error_detail( monkeypatch: pytest.MonkeyPatch, From a4f402d7bd97bcbb894b2d065af626ddc63e95e4 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 16:35:09 +0800 Subject: [PATCH 06/21] feat(multimodal): add attachment uid registry and pic embeds --- res/prompts/undefined.xml | 9 + res/prompts/undefined_nagaagent.xml | 9 + src/Undefined/ai/client.py | 6 + src/Undefined/ai/prompts.py | 13 +- src/Undefined/api/app.py | 78 +- src/Undefined/attachments.py | 785 ++++++++++++++++++ src/Undefined/handlers.py | 52 ++ src/Undefined/services/ai_coordinator.py | 36 +- .../agents/entertainment_agent/prompt.md | 2 + .../tools/ai_draw_one/config.json | 11 +- .../tools/ai_draw_one/handler.py | 111 ++- .../agents/file_analysis_agent/config.json | 4 +- .../agents/file_analysis_agent/handler.py | 21 +- .../agents/file_analysis_agent/prompt.md | 5 + .../tools/download_file/config.json | 4 +- .../tools/download_file/handler.py | 63 ++ .../toolsets/messages/send_message/handler.py | 28 +- .../messages/send_private_message/handler.py | 27 +- src/Undefined/utils/history.py | 36 + src/Undefined/utils/paths.py | 2 + src/Undefined/utils/sender.py | 16 +- src/Undefined/webui/static/js/runtime.js | 6 + tests/test_ai_draw_one_handler.py | 110 +++ tests/test_attachments.py | 99 +++ tests/test_file_analysis_attachment_uid.py | 41 + tests/test_send_message_tool.py | 50 +- tests/test_send_private_message_tool.py | 6 + 27 files changed, 1596 insertions(+), 34 deletions(-) create mode 100644 src/Undefined/attachments.py create mode 100644 tests/test_attachments.py create mode 100644 tests/test_file_analysis_attachment_uid.py diff --git a/res/prompts/undefined.xml b/res/prompts/undefined.xml index 95ed07a4..cc6ec7d2 100644 --- 
a/res/prompts/undefined.xml +++ b/res/prompts/undefined.xml @@ -131,6 +131,15 @@ **关键点:每次消息处理都必须以 end 结束,这是维持对话流的核心机制。** + + + **图文混排规则:** + - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` + - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 只能引用当前会话里明确给出的图片 UID,禁止臆造 UID + - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index 977c744c..eaedb148 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -131,6 +131,15 @@ **关键点:每次消息处理都必须以 end 结束,这是维持对话流的核心机制。** + + + **图文混排规则:** + - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` + - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 只能引用当前会话里明确给出的图片 UID,禁止臆造 UID + - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index 7c526ce6..7c6beb3a 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -12,6 +12,7 @@ import httpx +from Undefined.attachments import AttachmentRegistry from Undefined.ai.llm import ModelRequester from Undefined.ai.model_selector import ModelSelector from Undefined.ai.multimodal import MultimodalAnalyzer @@ -136,6 +137,7 @@ def __init__( self._token_counter = TokenCounter() self._knowledge_manager: Any = None self._cognitive_service: Any = cognitive_service + self.attachment_registry = AttachmentRegistry(http_client=self._http_client) # 私聊发送回调 self._send_private_message_callback: Optional[SendPrivateMessageCallback] = None @@ -973,6 +975,10 @@ async def ask( tool_context.setdefault("onebot_client", onebot_client) tool_context.setdefault("scheduler", scheduler) tool_context.setdefault("send_image_callback", self._send_image_callback) + tool_context.setdefault( + "attachment_registry", + getattr(self, "attachment_registry", None), + ) 
tool_context.setdefault("memory_storage", self.memory_storage) tool_context.setdefault("knowledge_manager", self._knowledge_manager) tool_context.setdefault("cognitive_service", self._cognitive_service) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index 2721efce..d9fc16d1 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -11,6 +11,7 @@ import aiofiles +from Undefined.attachments import attachment_refs_to_xml from Undefined.context import RequestContext from Undefined.end_summary_storage import ( EndSummaryStorage, @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) _CURRENT_MESSAGE_RE = re.compile( - r"[^>]*)>\s*(?P.*?)\s*", + r"[^>]*)>.*?(?P.*?).*?", re.DOTALL | re.IGNORECASE, ) _XML_ATTR_RE = re.compile(r'(?P[a-zA-Z_][a-zA-Z0-9_-]*)="(?P[^"]*)"') @@ -648,6 +649,7 @@ async def _inject_recent_messages( chat_name = msg.get("chat_name", "未知群聊") timestamp = msg.get("timestamp", "") text = msg.get("message", "") + attachments = msg.get("attachments", []) role = msg.get("role", "member") title = msg.get("title", "") message_id = msg.get("message_id") @@ -664,6 +666,11 @@ async def _inject_recent_messages( msg_id_attr = "" if message_id is not None: msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' + attachment_xml = ( + f"\n{attachment_refs_to_xml(attachments)}" + if isinstance(attachments, list) and attachments + else "" + ) if msg_type_val == "group": location = ( @@ -673,14 +680,14 @@ async def _inject_recent_messages( xml_msg = ( f'\n{safe_text}\n' + f'time="{safe_time}">\n{safe_text}{attachment_xml}\n' ) else: location = "私聊" safe_location = escape_xml_attr(location) xml_msg = ( f'\n{safe_text}\n' + f'time="{safe_time}">\n{safe_text}{attachment_xml}\n' ) context_lines.append(xml_msg) diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index ea33b8b2..1381b542 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -21,11 +21,18 @@ from aiohttp.web_response import 
Response from Undefined import __version__ +from Undefined.attachments import ( + attachment_refs_to_xml, + build_attachment_scope, + register_message_attachments, + render_message_with_pic_placeholders, +) from Undefined.config import load_webui_settings from Undefined.context import RequestContext from Undefined.context_resource_registry import collect_context_resources from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN +from Undefined.utils.common import message_to_segments from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin from Undefined.utils.recent_messages import get_recent_messages_prefer_local from Undefined.utils.xml import escape_xml_attr, escape_xml_text @@ -75,8 +82,16 @@ async def send_private_message( mark_sent: bool = True, reply_to: int | None = None, preferred_temp_group_id: int | None = None, + history_message: str | None = None, ) -> int | None: - _ = user_id, auto_history, mark_sent, reply_to, preferred_temp_group_id + _ = ( + user_id, + auto_history, + mark_sent, + reply_to, + preferred_temp_group_id, + history_message, + ) await self._send_private_callback(self._virtual_user_id, message) return None @@ -89,8 +104,16 @@ async def send_group_message( *, mark_sent: bool = True, reply_to: int | None = None, + history_message: str | None = None, ) -> int | None: - _ = group_id, auto_history, history_prefix, mark_sent, reply_to + _ = ( + group_id, + auto_history, + history_prefix, + mark_sent, + reply_to, + history_message, + ) await self._send_private_callback(self._virtual_user_id, message) return None @@ -1181,14 +1204,28 @@ async def _run_webui_chat( ) -> str: cfg = self._ctx.config_getter() permission_sender_id = int(cfg.superadmin_qq) + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + input_segments = message_to_segments(text) + 
registered_input = await register_message_attachments( + registry=self._ctx.ai.attachment_registry, + segments=input_segments, + scope_key=webui_scope_key, + resolve_image_url=self._ctx.onebot.get_image, + ) + normalized_text = registered_input.normalized_text or text await self._ctx.history_manager.add_private_message( user_id=_VIRTUAL_USER_ID, - text_content=text, + text_content=normalized_text, display_name=_VIRTUAL_USER_NAME, user_name=_VIRTUAL_USER_NAME, + attachments=registered_input.attachments, ) - command = self._ctx.command_dispatcher.parse_command(text) + command = self._ctx.command_dispatcher.parse_command(normalized_text) if command: await self._ctx.command_dispatcher.dispatch_private( user_id=_VIRTUAL_USER_ID, @@ -1200,8 +1237,13 @@ async def _run_webui_chat( return "command" current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + attachment_xml = ( + f"\n{attachment_refs_to_xml(registered_input.attachments)}" + if registered_input.attachments + else "" + ) full_question = f""" - {escape_xml_text(text)} + {escape_xml_text(normalized_text)}{attachment_xml} 【WebUI 会话】 @@ -1334,18 +1376,32 @@ async def _chat_handler(self, request: web.Request) -> web.StreamResponse: stream = _to_bool(body.get("stream")) outputs: list[str] = [] + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) async def _capture_private_message(user_id: int, message: str) -> None: _ = user_id content = str(message or "").strip() if not content: return - outputs.append(content) + rendered = await render_message_with_pic_placeholders( + content, + registry=self._ctx.ai.attachment_registry, + scope_key=webui_scope_key, + strict=False, + ) + if not rendered.delivery_text.strip(): + return + outputs.append(rendered.delivery_text) await self._ctx.history_manager.add_private_message( user_id=_VIRTUAL_USER_ID, - text_content=content, + text_content=rendered.history_text, display_name="Bot", user_name="Bot", + 
attachments=rendered.attachments, ) if not stream: @@ -1373,7 +1429,13 @@ async def _capture_private_message(user_id: int, message: str) -> None: async def _capture_private_message_stream(user_id: int, message: str) -> None: await _capture_private_message(user_id, message) - content = str(message or "").strip() + rendered = await render_message_with_pic_placeholders( + str(message or "").strip(), + registry=self._ctx.ai.attachment_registry, + scope_key=webui_scope_key, + strict=False, + ) + content = rendered.delivery_text.strip() if content: await message_queue.put(content) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py new file mode 100644 index 00000000..a8f187ac --- /dev/null +++ b/src/Undefined/attachments.py @@ -0,0 +1,785 @@ +"""Attachment registry and rich-media helpers.""" + +from __future__ import annotations + +import asyncio +import base64 +import binascii +from dataclasses import asdict, dataclass +from datetime import datetime +import hashlib +import json +import logging +import mimetypes +from pathlib import Path +import re +from typing import Any, Awaitable, Callable, Mapping, Sequence +from urllib.parse import unquote, urlsplit + +import httpx + +from Undefined.utils import io +from Undefined.utils.paths import ( + ATTACHMENT_CACHE_DIR, + ATTACHMENT_REGISTRY_FILE, + WEBUI_FILE_CACHE_DIR, + ensure_dir, +) +from Undefined.utils.xml import escape_xml_attr + +logger = logging.getLogger(__name__) + +_PIC_TAG_PATTERN = re.compile( + r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_MEDIA_LABELS = { + "image": "图片", + "file": "文件", + "audio": "音频", + "video": "视频", + "record": "语音", +} +_WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") +_DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 +_IMAGE_SUFFIX_TO_MIME = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + ".svg": "image/svg+xml", +} +_MAGIC_IMAGE_SUFFIXES: tuple[tuple[bytes, 
str], ...] = ( + (b"\x89PNG\r\n\x1a\n", ".png"), + (b"\xff\xd8\xff", ".jpg"), + (b"GIF87a", ".gif"), + (b"GIF89a", ".gif"), + (b"BM", ".bmp"), +) + + +@dataclass(frozen=True) +class AttachmentRecord: + uid: str + scope_key: str + kind: str + media_type: str + display_name: str + source_kind: str + source_ref: str + local_path: str | None + mime_type: str + sha256: str + created_at: str + + def prompt_ref(self) -> dict[str, str]: + return { + "uid": self.uid, + "kind": self.kind, + "media_type": self.media_type, + "display_name": self.display_name, + } + + +@dataclass(frozen=True) +class RegisteredMessageAttachments: + attachments: list[dict[str, str]] + normalized_text: str + + +@dataclass(frozen=True) +class RenderedRichMessage: + delivery_text: str + history_text: str + attachments: list[dict[str, str]] + + +class AttachmentRenderError(RuntimeError): + """Raised when a `` tag cannot be rendered.""" + + +def _now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def _coerce_positive_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value if value > 0 else None + if isinstance(value, str): + text = value.strip() + if not text: + return None + try: + parsed = int(text) + except ValueError: + return None + return parsed if parsed > 0 else None + return None + + +def build_attachment_scope( + *, + group_id: Any = None, + user_id: Any = None, + request_type: str | None = None, + webui_session: bool = False, +) -> str | None: + """Build a scope key for attachment visibility.""" + if webui_session: + return "webui" + + group = _coerce_positive_int(group_id) + if group is not None: + return f"group:{group}" + + user = _coerce_positive_int(user_id) + request_type_text = str(request_type or "").strip().lower() + if request_type_text == "private" and user is not None: + return f"private:{user}" + if request_type_text == "group" and group is not None: + return f"group:{group}" + if user is 
not None: + return f"private:{user}" + return None + + +def scope_from_context(context: Mapping[str, Any] | None) -> str | None: + if not context: + return None + return build_attachment_scope( + group_id=context.get("group_id"), + user_id=context.get("user_id"), + request_type=str(context.get("request_type", "") or ""), + webui_session=bool(context.get("webui_session", False)), + ) + + +def attachment_refs_to_text(attachments: Sequence[Mapping[str, str]]) -> str: + if not attachments: + return "" + parts: list[str] = [] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + media_type = str(item.get("media_type") or item.get("kind") or "file").strip() + label = _MEDIA_LABELS.get(media_type, "附件") + name = str(item.get("display_name", "") or "").strip() + if name: + parts.append(f"[{label} uid={uid} name={name}]") + else: + parts.append(f"[{label} uid={uid}]") + return " ".join(parts) + + +def attachment_refs_to_xml( + attachments: Sequence[Mapping[str, str]], + *, + indent: str = " ", +) -> str: + if not attachments: + return "" + lines = [f"{indent}"] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + kind = str(item.get("kind", "") or item.get("media_type", "") or "file").strip() + media_type = str(item.get("media_type", "") or kind or "file").strip() + name = str(item.get("display_name", "") or "").strip() + attrs = [ + f'uid="{escape_xml_attr(uid)}"', + f'type="{escape_xml_attr(kind or media_type)}"', + f'media_type="{escape_xml_attr(media_type)}"', + ] + if name: + attrs.append(f'name="{escape_xml_attr(name)}"') + lines.append(f"{indent} ") + lines.append(f"{indent}") + return "\n".join(lines) + + +def append_attachment_text( + base_text: str, attachments: Sequence[Mapping[str, str]] +) -> str: + attachment_text = attachment_refs_to_text(attachments) + if not attachment_text: + return base_text + if not base_text.strip(): + return attachment_text + return 
f"{base_text}\n附件: {attachment_text}" + + +def _is_http_url(value: str) -> bool: + return value.startswith("http://") or value.startswith("https://") + + +def _is_data_url(value: str) -> bool: + return value.startswith("data:") + + +def _is_localish_path(value: str) -> bool: + return ( + value.startswith("/") + or value.startswith("file://") + or bool(_WINDOWS_ABS_PATH_RE.match(value)) + ) + + +def _decode_data_url(data_url: str) -> tuple[bytes, str]: + header, _, payload = data_url.partition(",") + if ";base64" not in header.lower(): + raise ValueError("unsupported data URL encoding") + mime_type = ( + header.split(":", 1)[1].split(";", 1)[0].strip() or "application/octet-stream" + ) + return base64.b64decode(payload), mime_type + + +def _guess_suffix_from_bytes(content: bytes) -> str: + for magic, suffix in _MAGIC_IMAGE_SUFFIXES: + if content.startswith(magic): + return suffix + if content.startswith(b"RIFF") and content[8:12] == b"WEBP": + return ".webp" + return ".bin" + + +def _guess_suffix(name: str, content: bytes, mime_type: str) -> str: + suffix = Path(name).suffix.lower() + if suffix: + return suffix + guessed_ext = mimetypes.guess_extension(mime_type or "") + if guessed_ext: + return guessed_ext.lower() + return _guess_suffix_from_bytes(content) + + +def _guess_mime_type(name: str, content: bytes) -> str: + guessed, _ = mimetypes.guess_type(name) + if guessed: + return guessed + suffix = _guess_suffix_from_bytes(content) + return _IMAGE_SUFFIX_TO_MIME.get(suffix, "application/octet-stream") + + +def _display_name_from_source(raw_source: str, fallback: str) -> str: + if not raw_source: + return fallback + if raw_source.startswith("file://"): + raw_source = raw_source[7:] + name = Path(unquote(urlsplit(raw_source).path)).name + return name or fallback + + +def _media_kind_from_value(value: str) -> str: + text = str(value or "").strip().lower() + if text in {"image", "file", "audio", "video", "record"}: + return text + return "file" + + +def _segment_text( 
+ type_: str, data: Mapping[str, Any], ref: Mapping[str, str] | None +) -> str: + if type_ == "text": + return str(data.get("text", "") or "") + if type_ == "at": + qq = str(data.get("qq", "") or "").strip() + name = str(data.get("name") or data.get("nickname") or "").strip() + if qq and name: + return f"[@{qq}({name})]" + if qq: + return f"[@{qq}]" + return "[@]" + if type_ == "face": + return "[表情]" + if type_ == "reply": + reply_id = str(data.get("id") or data.get("message_id") or "").strip() + return f"[引用: {reply_id}]" if reply_id else "[引用]" + if type_ == "forward": + forward_id = str(data.get("id") or data.get("resid") or "").strip() + return f"[合并转发: {forward_id}]" if forward_id else "[合并转发]" + if ref is not None: + label = _MEDIA_LABELS.get( + str(ref.get("media_type") or ref.get("kind") or type_).strip(), "附件" + ) + uid = str(ref.get("uid", "") or "").strip() + name = str(ref.get("display_name", "") or "").strip() + if uid and name: + return f"[{label} uid={uid} name={name}]" + if uid: + return f"[{label} uid={uid}]" + label = _MEDIA_LABELS.get(type_, "附件") + raw = str(data.get("file") or data.get("url") or data.get("id") or "").strip() + return f"[{label}: {raw}]" if raw else f"[{label}]" + + +def _resolve_webui_file_id(file_id: str) -> Path | None: + if not file_id or not file_id.isalnum(): + return None + file_dir = (Path.cwd() / WEBUI_FILE_CACHE_DIR / file_id).resolve() + cache_root = (Path.cwd() / WEBUI_FILE_CACHE_DIR).resolve() + if cache_root not in file_dir.parents and file_dir != cache_root: + return None + if not file_dir.is_dir(): + return None + try: + files = list(file_dir.iterdir()) + except OSError: + return None + for candidate in files: + if candidate.is_file(): + return candidate + return None + + +class AttachmentRegistry: + """Persistent attachment registry scoped by conversation.""" + + def __init__( + self, + *, + registry_path: Path = ATTACHMENT_REGISTRY_FILE, + cache_dir: Path = ATTACHMENT_CACHE_DIR, + http_client: 
httpx.AsyncClient | None = None, + ) -> None: + self._registry_path = registry_path + self._cache_dir = cache_dir + self._http_client = http_client + self._lock = asyncio.Lock() + self._records: dict[str, AttachmentRecord] = {} + self._load_from_disk() + + def _load_from_disk(self) -> None: + if not self._registry_path.exists(): + return + try: + raw = json.loads(self._registry_path.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("[AttachmentRegistry] 读取失败: %s", exc) + return + if not isinstance(raw, dict): + return + loaded: dict[str, AttachmentRecord] = {} + for uid, item in raw.items(): + if not isinstance(item, dict): + continue + try: + loaded[str(uid)] = AttachmentRecord( + uid=str(item.get("uid") or uid), + scope_key=str(item.get("scope_key", "") or ""), + kind=_media_kind_from_value(item.get("kind", "file")), + media_type=_media_kind_from_value( + item.get("media_type") or item.get("kind") or "file" + ), + display_name=str(item.get("display_name", "") or ""), + source_kind=str(item.get("source_kind", "") or ""), + source_ref=str(item.get("source_ref", "") or ""), + local_path=str(item.get("local_path", "") or "") or None, + mime_type=str( + item.get("mime_type", "") or "application/octet-stream" + ), + sha256=str(item.get("sha256", "") or ""), + created_at=str(item.get("created_at", "") or ""), + ) + except Exception: + continue + self._records = loaded + + async def _persist(self) -> None: + payload = {uid: asdict(record) for uid, record in self._records.items()} + await io.write_json(self._registry_path, payload, use_lock=True) + + def get(self, uid: str) -> AttachmentRecord | None: + return self._records.get(str(uid).strip()) + + def resolve(self, uid: str, scope_key: str | None) -> AttachmentRecord | None: + record = self.get(uid) + if record is None: + return None + if scope_key and record.scope_key != scope_key: + return None + return record + + def resolve_for_context( + self, + uid: str, + context: Mapping[str, Any] | None, 
+ ) -> AttachmentRecord | None: + return self.resolve(uid, scope_from_context(context)) + + def _build_uid(self, prefix: str) -> str: + from uuid import uuid4 + + while True: + uid = f"{prefix}_{uuid4().hex[:8]}" + if uid not in self._records: + return uid + + async def register_bytes( + self, + scope_key: str, + content: bytes, + *, + kind: str, + display_name: str, + source_kind: str, + source_ref: str = "", + mime_type: str | None = None, + ) -> AttachmentRecord: + normalized_kind = _media_kind_from_value(kind) + normalized_media_type = ( + "image" if normalized_kind == "image" else normalized_kind + ) + normalized_mime = mime_type or _guess_mime_type(display_name, content) + suffix = _guess_suffix(display_name, content, normalized_mime) + prefix = "pic" if normalized_media_type == "image" else "file" + + async with self._lock: + uid = self._build_uid(prefix) + file_name = f"{uid}{suffix}" + cache_path = ensure_dir(self._cache_dir) / file_name + + def _write() -> str: + cache_path.write_bytes(content) + return hashlib.sha256(content).hexdigest() + + digest = await asyncio.to_thread(_write) + record = AttachmentRecord( + uid=uid, + scope_key=scope_key, + kind=normalized_kind, + media_type=normalized_media_type, + display_name=display_name or file_name, + source_kind=source_kind, + source_ref=source_ref, + local_path=str(cache_path), + mime_type=normalized_mime, + sha256=digest, + created_at=_now_iso(), + ) + self._records[uid] = record + await self._persist() + return record + + async def register_local_file( + self, + scope_key: str, + local_path: str | Path, + *, + kind: str, + display_name: str | None = None, + source_kind: str = "local_file", + source_ref: str = "", + ) -> AttachmentRecord: + path = Path(str(local_path)).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + else: + path = path.resolve() + if not path.is_file(): + raise FileNotFoundError(path) + + def _read() -> bytes: + return path.read_bytes() + + content = await 
asyncio.to_thread(_read) + return await self.register_bytes( + scope_key, + content, + kind=kind, + display_name=display_name or path.name, + source_kind=source_kind, + source_ref=source_ref or str(path), + mime_type=mimetypes.guess_type(path.name)[0] or None, + ) + + async def register_data_url( + self, + scope_key: str, + data_url: str, + *, + kind: str, + display_name: str, + source_kind: str, + source_ref: str = "", + ) -> AttachmentRecord: + content, mime_type = _decode_data_url(data_url) + return await self.register_bytes( + scope_key, + content, + kind=kind, + display_name=display_name, + source_kind=source_kind, + source_ref=source_ref, + mime_type=mime_type, + ) + + async def register_remote_url( + self, + scope_key: str, + url: str, + *, + kind: str, + display_name: str | None = None, + source_kind: str = "remote_url", + source_ref: str = "", + ) -> AttachmentRecord: + timeout = httpx.Timeout(_DEFAULT_REMOTE_TIMEOUT_SECONDS) + if self._http_client is not None: + response = await self._http_client.get( + url, timeout=timeout, follow_redirects=True + ) + else: + async with httpx.AsyncClient( + timeout=timeout, follow_redirects=True + ) as client: + response = await client.get(url) + response.raise_for_status() + name = display_name or _display_name_from_source(url, "attachment.bin") + mime_type = response.headers.get("content-type", "").split(";", 1)[0].strip() + return await self.register_bytes( + scope_key, + response.content, + kind=kind, + display_name=name, + source_kind=source_kind, + source_ref=source_ref or url, + mime_type=mime_type or None, + ) + + +async def register_message_attachments( + *, + registry: AttachmentRegistry | None, + segments: Sequence[Mapping[str, Any]], + scope_key: str | None, + resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, +) -> RegisteredMessageAttachments: + attachments: list[dict[str, str]] = [] + normalized_parts: list[str] = [] + if registry is None or not scope_key: + for segment in segments: + 
type_ = str(segment.get("type", "") or "") + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + normalized_parts.append(_segment_text(type_, data, None)) + return RegisteredMessageAttachments( + attachments=[], + normalized_text="".join(normalized_parts).strip(), + ) + + for index, segment in enumerate(segments): + type_ = str(segment.get("type", "") or "").strip().lower() + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + ref: dict[str, str] | None = None + + try: + if type_ == "image": + raw_source = str(data.get("file") or data.get("url") or "").strip() + display_name = _display_name_from_source( + raw_source, + f"image_{index + 1}.png", + ) + if raw_source.startswith("base64://"): + payload = raw_source[len("base64://") :].strip() + content = base64.b64decode(payload) + record = await registry.register_bytes( + scope_key, + content, + kind="image", + display_name=display_name, + source_kind="base64_image", + source_ref=f"segment:{index}", + ) + ref = record.prompt_ref() + elif _is_data_url(raw_source): + record = await registry.register_data_url( + scope_key, + raw_source, + kind="image", + display_name=display_name, + source_kind="data_url_image", + source_ref=f"segment:{index}", + ) + ref = record.prompt_ref() + else: + resolved_source = raw_source + if raw_source and resolve_image_url is not None: + try: + resolved = await resolve_image_url(raw_source) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] image resolver failed: file=%s err=%s", + raw_source, + exc, + ) + resolved = None + if resolved: + resolved_source = str(resolved) + + if _is_http_url(resolved_source): + record = await registry.register_remote_url( + scope_key, + resolved_source, + kind="image", + display_name=display_name, + source_kind="remote_image", + source_ref=raw_source or resolved_source, + ) + ref = record.prompt_ref() + elif _is_localish_path(resolved_source): + local_path = 
( + resolved_source[7:] + if resolved_source.startswith("file://") + else resolved_source + ) + record = await registry.register_local_file( + scope_key, + local_path, + kind="image", + display_name=display_name, + source_kind="local_image", + source_ref=raw_source or resolved_source, + ) + ref = record.prompt_ref() + + elif type_ == "file": + file_id = str(data.get("id", "") or "").strip() + raw_source = str(data.get("file") or data.get("url") or "").strip() + local_file_path: Path | None = None + if file_id: + local_file_path = _resolve_webui_file_id(file_id) + elif _is_localish_path(raw_source): + local_file_path = Path( + raw_source[7:] + if raw_source.startswith("file://") + else raw_source + ) + display_name = ( + str(data.get("name", "") or "").strip() + or (local_file_path.name if local_file_path is not None else "") + or _display_name_from_source(raw_source, f"file_{index + 1}.bin") + ) + if local_file_path is not None and local_file_path.is_file(): + record = await registry.register_local_file( + scope_key, + local_file_path, + kind="file", + display_name=display_name, + source_kind="webui_file" if file_id else "local_file", + source_ref=file_id or raw_source or str(local_file_path), + ) + ref = record.prompt_ref() + elif _is_http_url(raw_source): + record = await registry.register_remote_url( + scope_key, + raw_source, + kind="file", + display_name=display_name, + source_kind="remote_file", + source_ref=file_id or raw_source, + ) + ref = record.prompt_ref() + except (binascii.Error, ValueError, FileNotFoundError, httpx.HTTPError) as exc: + logger.warning( + "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", + type_, + index, + exc, + ) + except Exception as exc: + logger.exception( + "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", + type_, + index, + exc, + ) + + if ref is not None: + attachments.append(ref) + normalized_parts.append(_segment_text(type_, data, ref)) + + return 
RegisteredMessageAttachments( + attachments=attachments, + normalized_text="".join(normalized_parts).strip(), + ) + + +async def render_message_with_pic_placeholders( + message: str, + *, + registry: AttachmentRegistry | None, + scope_key: str | None, + strict: bool, +) -> RenderedRichMessage: + if ( + not message + or registry is None + or not scope_key + or ":{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + continue + + image_source = record.source_ref + if record.local_path: + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + continue + + delivery_parts.append(f"[CQ:image,file={image_source}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + attachments.append(record.prompt_ref()) + + delivery_parts.append(message[last_index:]) + history_parts.append(message[last_index:]) + return RenderedRichMessage( + delivery_text="".join(delivery_parts), + history_text="".join(history_parts), + attachments=attachments, + ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 97c28a56..fb3f61c1 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -8,6 +8,11 @@ import random from typing import Any, Coroutine +from Undefined.attachments import ( + append_attachment_text, + build_attachment_scope, + register_message_attachments, +) from Undefined.ai import AIClient from Undefined.config import Config from Undefined.faq import FAQStorage @@ -117,6 +122,37 @@ def __init__( # 启动队列 self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) + async def _collect_message_attachments( + self, + message_content: list[dict[str, Any]], + *, + group_id: int | None = None, + user_id: int | 
None = None, + request_type: str, + ) -> list[dict[str, str]]: + scope_key = build_attachment_scope( + group_id=group_id, + user_id=user_id, + request_type=request_type, + ) + if not scope_key: + return [] + ai_client = getattr(self, "ai", None) + attachment_registry = ( + getattr(ai_client, "attachment_registry", None) if ai_client else None + ) + if attachment_registry is None: + return [] + onebot = getattr(self, "onebot", None) + resolve_image_url = getattr(onebot, "get_image", None) if onebot else None + result = await register_message_attachments( + registry=attachment_registry, + segments=message_content, + scope_key=scope_key, + resolve_image_url=resolve_image_url, + ) + return result.attachments + async def handle_message(self, event: dict[str, Any]) -> None: """处理收到的消息事件""" if logger.isEnabledFor(logging.DEBUG): @@ -247,6 +283,11 @@ async def handle_message(self, event: dict[str, Any]) -> None: logger.warning("获取用户昵称失败: %s", exc) text = extract_text(private_message_content, self.config.bot_qq) + private_attachments = await self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ) safe_text = redact_string(text) logger.info( "[私聊消息] 发送者=%s 昵称=%s 内容=%s", @@ -263,6 +304,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: self.onebot.get_msg, self.onebot.get_forward_msg, ) + parsed_content = append_attachment_text(parsed_content, private_attachments) safe_parsed = redact_string(parsed_content) logger.debug( "[历史记录] 保存私聊: user=%s content=%s...", @@ -275,6 +317,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: display_name=private_sender_nickname, user_name=user_name, message_id=trigger_message_id, + attachments=private_attachments, ) # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 @@ -337,6 +380,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: private_sender_id, text, private_message_content, + attachments=private_attachments, sender_name=user_name, 
trigger_message_id=trigger_message_id, ) @@ -372,6 +416,11 @@ async def handle_message(self, event: dict[str, Any]) -> None: # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) + group_attachments = await self._collect_message_attachments( + message_content, + group_id=group_id, + request_type="group", + ) safe_text = redact_string(text) logger.info( f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " @@ -395,6 +444,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: self.onebot.get_msg, self.onebot.get_forward_msg, ) + parsed_content = append_attachment_text(parsed_content, group_attachments) safe_parsed = redact_string(parsed_content) logger.debug( f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." @@ -409,6 +459,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: role=sender_role, title=sender_title, message_id=trigger_message_id, + attachments=group_attachments, ) # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 @@ -511,6 +562,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_id, text, message_content, + attachments=group_attachments, sender_name=display_name, group_name=group_name, sender_role=sender_role, diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 61f53afe..c73ba903 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -3,6 +3,11 @@ from pathlib import Path from typing import Any, Optional +from Undefined.attachments import ( + attachment_refs_to_xml, + build_attachment_scope, + render_message_with_pic_placeholders, +) from Undefined.config import Config from Undefined.context import RequestContext from Undefined.context_resource_registry import collect_context_resources @@ -61,6 +66,7 @@ async def handle_auto_reply( sender_id: int, text: str, message_content: list[dict[str, Any]], + attachments: list[dict[str, str]] | None = None, 
is_poke: bool = False, sender_name: str = "未知用户", group_name: str = "未知群聊", @@ -122,7 +128,8 @@ async def handle_auto_reply( sender_title, current_time, text, - trigger_message_id, + attachments=attachments, + message_id=trigger_message_id, ) logger.debug( "[自动回复] full_question_len=%s group=%s sender=%s", @@ -164,6 +171,7 @@ async def handle_private_reply( user_id: int, text: str, message_content: list[dict[str, Any]], + attachments: list[dict[str, str]] | None = None, is_poke: bool = False, sender_name: str = "未知用户", trigger_message_id: int | None = None, @@ -184,8 +192,12 @@ async def handle_private_reply( message_id_attr = "" if trigger_message_id is not None: message_id_attr = f' message_id="{escape_xml_attr(trigger_message_id)}"' + attachment_xml = ( + f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" + ) full_question = f"""{prompt_prefix} {escape_xml_text(text)} +{attachment_xml} 【私聊消息】 @@ -440,7 +452,21 @@ async def send_private_cb( }, ) if result: - await self.sender.send_private_message(user_id, result) + scope_key = build_attachment_scope( + user_id=user_id, + request_type="private", + ) + rendered = await render_message_with_pic_placeholders( + str(result), + registry=self.ai.attachment_registry, + scope_key=scope_key, + strict=False, + ) + await self.sender.send_private_message( + user_id, + rendered.delivery_text, + history_message=rendered.history_text, + ) except Exception: logger.exception("私聊回复执行出错") raise @@ -674,6 +700,7 @@ def _build_prompt( title: str, time_str: str, text: str, + attachments: list[dict[str, str]] | None = None, message_id: int | None = None, ) -> str: """构建最终发送给 AI 的结构化 XML 消息 Prompt @@ -692,8 +719,11 @@ def _build_prompt( message_id_attr = "" if message_id is not None: message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + attachment_xml = ( + f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" + ) return f"""{prefix} - {safe_text} + {safe_text}{attachment_xml} 【回复策略 - 极低频参与】 diff 
--git a/src/Undefined/skills/agents/entertainment_agent/prompt.md b/src/Undefined/skills/agents/entertainment_agent/prompt.md index f1be2f90..cff647b3 100644 --- a/src/Undefined/skills/agents/entertainment_agent/prompt.md +++ b/src/Undefined/skills/agents/entertainment_agent/prompt.md @@ -5,6 +5,8 @@ - 适当给出可选项,让用户选择方向。 - 用户明确要“随机视频/刷个视频”时,优先调用视频推荐工具。 - 输出轻松友好,但不要过度承诺或编造。 +- 如果工具返回了图片 UID,且用户需要图文并茂的结果,优先在最终回复里用 `` 做图文混排,而不是单独口头描述“我发图了”。 +- `` 只能引用当前会话里真实存在的图片 UID,不能臆造,也不要改写成 Markdown 图片语法。 边界提醒: - 正经资讯或需要核验的信息,引导至 info_agent / web_agent。 diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json index c309fa93..e6d8b525 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json @@ -27,21 +27,26 @@ "description": "图片响应格式(仅 models 模式生效)", "enum": ["url", "b64_json", "base64"] }, + "delivery": { + "type": "string", + "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到当前或显式指定目标", + "enum": ["embed", "send"] + }, "n": { "type": "integer", "description": "生成图片数量(仅 models 模式生效,1 到 10)" }, "target_id": { "type": "integer", - "description": "发送目标的 ID" + "description": "发送目标的 ID(delivery=send 时可显式提供;不提供则尝试从当前会话推断)" }, "message_type": { "type": "string", - "description": "消息类型 (group 或 private)", + "description": "消息类型 (group 或 private,delivery=send 时可显式提供;不提供则尝试从当前会话推断)", "enum": ["group", "private"] } }, - "required": ["prompt", "target_id", "message_type"] + "required": ["prompt"] } } } diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 6215a46a..56d7683c 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py 
@@ -17,6 +17,7 @@ import httpx +from Undefined.attachments import scope_from_context from Undefined.skills.http_client import request_with_retry from Undefined.skills.http_config import get_request_timeout, get_xingzhige_url @@ -390,6 +391,74 @@ async def _save_and_send( return await _send_cached_image(filepath, target_id, message_type, context) +def _resolve_send_target( + target_id: int | str | None, + message_type: str | None, + context: dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + if target_id is not None and message_type is not None: + return target_id, message_type, None + + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + resolved_group_id = context.get("group_id") + if resolved_group_id is not None: + return resolved_group_id, "group", None + if request_type == "private": + resolved_user_id = context.get("user_id") + if resolved_user_id is not None: + return resolved_user_id, "private", None + + return None, None, "图片生成成功,但缺少发送目标参数" + + +async def _register_generated_image( + generated_image: _GeneratedImagePayload, + context: dict[str, Any], +) -> tuple[Any | None, str | None]: + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry is None or not scope_key: + return None, "当前会话未提供附件注册能力,无法生成可嵌入图片 UID" + + display_name = f"ai_draw_{uuid.uuid4().hex[:8]}.png" + if generated_image.image_bytes is not None: + record = await attachment_registry.register_bytes( + scope_key, + generated_image.image_bytes, + kind="image", + display_name=display_name, + source_kind="generated_image", + source_ref="ai_draw_one", + ) + return record, None + + if generated_image.image_url: + record = await attachment_registry.register_remote_url( + scope_key, + generated_image.image_url, + kind="image", + display_name=display_name, + source_kind="generated_image_url", + source_ref=generated_image.image_url, + ) + return record, None 
+ + return None, "图片生成失败:未找到可保存的图片内容" + + +async def _send_registered_record( + record: Any, + target_id: int | str, + message_type: str, + context: dict[str, Any], +) -> str: + local_path = str(getattr(record, "local_path", "") or "").strip() + if not local_path: + return "图片生成失败:已生成图片,但本地缓存不可用" + return await _send_cached_image(local_path, target_id, message_type, context) + + async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: """执行 AI 绘图""" from Undefined.config import get_config @@ -400,6 +469,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: style_arg: str | None = args.get("style") response_format_arg: str | None = args.get("response_format") n_arg = args.get("n") + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id: int | str | None = args.get("target_id") message_type_arg: str | None = args.get("message_type") @@ -414,6 +484,9 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: generated_result: str | _GeneratedImagePayload try: + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + if provider == "xingzhige": prompt = prompt_arg or "" size = size_arg or cfg.xingzhige_size @@ -468,21 +541,49 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: return generated_result generated_image = _GeneratedImagePayload(image_url=generated_result) - if target_id is None or message_type_arg is None: + registered_record, register_error = await _register_generated_image( + generated_image, + context, + ) + if delivery == "embed": + if register_error or registered_record is None: + return register_error or "图片生成失败:无法创建内嵌图片 UID" + success = True + uid = str(getattr(registered_record, "uid", "") or "").strip() + return f'已生成图片,可在回复中插入 ' + + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, + message_type_arg, + context, + ) + if target_error: + return target_error + if 
resolved_target_id is None or resolved_message_type is None: return "图片生成成功,但缺少发送目标参数" send_timeout = get_request_timeout(60.0) - if generated_image.image_url: + if registered_record is not None: + result = await _send_registered_record( + registered_record, + resolved_target_id, + resolved_message_type, + context, + ) + elif generated_image.image_url: result = await _download_and_send( generated_image.image_url, - target_id, - message_type_arg, + resolved_target_id, + resolved_message_type, send_timeout, context, ) elif generated_image.image_bytes is not None: result = await _save_and_send( - generated_image.image_bytes, target_id, message_type_arg, context + generated_image.image_bytes, + resolved_target_id, + resolved_message_type, + context, ) else: return "图片生成失败:未找到可发送的图片内容" diff --git a/src/Undefined/skills/agents/file_analysis_agent/config.json b/src/Undefined/skills/agents/file_analysis_agent/config.json index 4e248b3e..6d026f7e 100644 --- a/src/Undefined/skills/agents/file_analysis_agent/config.json +++ b/src/Undefined/skills/agents/file_analysis_agent/config.json @@ -2,13 +2,13 @@ "type": "function", "function": { "name": "file_analysis_agent", - "description": "文件分析助手,支持解析各种文件格式:文档(PDF、Word、PPT、Excel)、代码、压缩包、图片、音频、视频等。参数为 URL 或 file_id(QQ 内的文件 ID)。", + "description": "文件分析助手,支持解析各种文件格式:文档(PDF、Word、PPT、Excel)、代码、压缩包、图片、音频、视频等。参数优先使用内部附件 UID,也兼容 URL 或 legacy file_id。", "parameters": { "type": "object", "properties": { "file_source": { "type": "string", - "description": "文件源,可以是 URL 或 QQ 的 file_id" + "description": "文件源。优先传内部附件 UID(例如 pic_xxx / file_xxx);也兼容 URL 或 QQ 的 legacy file_id" }, "prompt": { "type": "string", diff --git a/src/Undefined/skills/agents/file_analysis_agent/handler.py b/src/Undefined/skills/agents/file_analysis_agent/handler.py index 3f8bb909..7db9cbe0 100644 --- a/src/Undefined/skills/agents/file_analysis_agent/handler.py +++ b/src/Undefined/skills/agents/file_analysis_agent/handler.py @@ -4,6 +4,7 @@ from pathlib import Path from 
typing import Any +from Undefined.attachments import scope_from_context from Undefined.skills.agents.runner import run_agent_with_tools logger = logging.getLogger(__name__) @@ -19,10 +20,28 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: return "请提供文件 URL 或 file_id" context["file_source"] = file_source + attachment_registry = context.get("attachment_registry") + resolved_record = None + if attachment_registry is not None: + resolved_record = attachment_registry.resolve( + file_source, scope_from_context(context) + ) + + source_description = file_source + if resolved_record is not None: + display_name = str(getattr(resolved_record, "display_name", "") or "").strip() + media_type = str(getattr(resolved_record, "media_type", "") or "").strip() + source_description = f"{file_source}(内部附件 UID" + if display_name: + source_description += f",{display_name}" + if media_type: + source_description += f",{media_type}" + source_description += ")" + context_messages = [ { "role": "system", - "content": f"当前任务附带文件源:{file_source}", + "content": f"当前任务附带文件源:{source_description}", } ] user_content = user_prompt if user_prompt else "请分析这个文件。" diff --git a/src/Undefined/skills/agents/file_analysis_agent/prompt.md b/src/Undefined/skills/agents/file_analysis_agent/prompt.md index 7ad81bd4..1bdf69e9 100644 --- a/src/Undefined/skills/agents/file_analysis_agent/prompt.md +++ b/src/Undefined/skills/agents/file_analysis_agent/prompt.md @@ -10,6 +10,11 @@ - 收到历史记录后,**优先根据已有的描述和提取内容来回答用户问题**,避免不必要的重复分析。 - 仅当历史记录确实无法覆盖当前需求时(如需要关注之前未涉及的细节),才再次调用并设置 `force_analyze=true` 进行全新分析。 +附件输入规则: +- 用户上下文里如果给了内部附件 UID(如 `pic_xxx` / `file_xxx`),优先直接使用这个 UID,不要先去猜 URL。 +- 只有在没有内部 UID 时,才回退到显式 URL 或 legacy `file_id`。 +- 不要臆造 UID;只能使用当前上下文明确给出的附件标识。 + 工作原则: - 先明确需要从文件中「识别或提取」什么(内容识别/摘要/提取/统计/结构),再选择工具。 - 不确定格式时可先尝试文本读取,再决定是否走专用解析器。 diff --git a/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/config.json 
b/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/config.json index ef359f89..d8dff0df 100644 --- a/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/config.json +++ b/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/config.json @@ -2,13 +2,13 @@ "type": "function", "function": { "name": "download_file", - "description": "下载文件到临时目录。支持 URL 或 QQ 的 file_id。对于 URL 会先尝试获取文件大小,获取失败则拒绝下载。", + "description": "下载文件到临时目录。优先支持内部附件 UID,也兼容 URL 或 QQ 的 legacy file_id。对于 URL 会先尝试获取文件大小,获取失败则拒绝下载。", "parameters": { "type": "object", "properties": { "file_source": { "type": "string", - "description": "文件源:URL 或 QQ 的 file_id" + "description": "文件源:优先传内部附件 UID,也兼容 URL 或 QQ 的 legacy file_id" }, "max_size_mb": { "type": "number", diff --git a/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/handler.py b/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/handler.py index 8729d76f..f12b6654 100644 --- a/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/handler.py +++ b/src/Undefined/skills/agents/file_analysis_agent/tools/download_file/handler.py @@ -5,6 +5,8 @@ import httpx import aiofiles +from Undefined.attachments import scope_from_context + logger = logging.getLogger(__name__) SIZE_LIMITS = { @@ -44,6 +46,21 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: temp_dir: Path = ensure_dir(DOWNLOAD_CACHE_DIR / task_uuid) + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry and scope_key: + try: + record = attachment_registry.resolve(file_source, scope_key) + except Exception: + record = None + if record is not None: + return await _download_from_attachment_record( + record, + temp_dir=temp_dir, + max_size_mb=max_size_mb, + task_uuid=task_uuid, + ) + is_url: bool = file_source.startswith("http://") or file_source.startswith( "https://" ) @@ -156,3 +173,49 @@ def 
_extract_filename_from_url(url: str) -> str: url = url.split("?")[0] filename = url.split("/")[-1] return filename + + +async def _download_from_attachment_record( + record: Any, + *, + temp_dir: Path, + max_size_mb: float, + task_uuid: str, +) -> str: + max_size_bytes: int = int(max_size_mb * 1024 * 1024) + local_path_raw = getattr(record, "local_path", None) + if local_path_raw: + local_path = Path(str(local_path_raw)) + if local_path.is_file(): + size = local_path.stat().st_size + if size > max_size_bytes: + return ( + f"错误:文件大小 ({size / 1024 / 1024:.2f}MB) " + f"超过限制 ({max_size_mb}MB)" + ) + display_name = str(getattr(record, "display_name", "") or "").strip() + filename = display_name or local_path.name or f"downloaded_{task_uuid}" + target = temp_dir / filename + async with aiofiles.open(local_path, "rb") as src: + content = await src.read() + target.write_bytes(content) + logger.info("附件 UID 已复制到: %s", target) + return str(target) + + source_ref = str(getattr(record, "source_ref", "") or "").strip() + if source_ref.startswith("http://") or source_ref.startswith("https://"): + return await _download_from_url(source_ref, temp_dir, max_size_mb, task_uuid) + + if source_ref: + candidate = Path(source_ref) + if candidate.exists() and candidate.is_file(): + display_name = str(getattr(record, "display_name", "") or "").strip() + filename = display_name or candidate.name or f"downloaded_{task_uuid}" + target = temp_dir / filename + async with aiofiles.open(candidate, "rb") as src: + content = await src.read() + target.write_bytes(content) + logger.info("附件 UID 源文件已复制到: %s", target) + return str(target) + + return f"错误:无法从附件 UID {getattr(record, 'uid', '')} 解析到可下载文件" diff --git a/src/Undefined/skills/toolsets/messages/send_message/handler.py b/src/Undefined/skills/toolsets/messages/send_message/handler.py index 2d115a6c..38f7dcfa 100644 --- a/src/Undefined/skills/toolsets/messages/send_message/handler.py +++ 
b/src/Undefined/skills/toolsets/messages/send_message/handler.py @@ -1,6 +1,11 @@ from typing import Any, Dict, Literal import logging +from Undefined.attachments import ( + render_message_with_pic_placeholders, + scope_from_context, +) + logger = logging.getLogger(__name__) @@ -198,6 +203,23 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.warning("[发送消息] 收到空消息请求") return "消息内容不能为空" + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + try: + rendered = await render_message_with_pic_placeholders( + message, + registry=attachment_registry, + scope_key=scope_key, + strict=True, + ) + except Exception as exc: + logger.warning( + "[发送消息] 图片内嵌渲染失败: request_id=%s err=%s", request_id, exc + ) + return f"发送失败:{exc}" + message = rendered.delivery_text + history_message = rendered.history_text + # 解析 reply_to 参数(无效值静默忽略,视为未传) reply_to_id, _ = _parse_positive_int(args.get("reply_to"), "reply_to") @@ -235,7 +257,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if target_type == "group": logger.info("[发送消息] 准备发送到群 %s: %s", target_id, message[:100]) sent_message_id = await sender.send_group_message( - target_id, message, reply_to=reply_to_id + target_id, + message, + reply_to=reply_to_id, + history_message=history_message, ) else: logger.info("[发送消息] 准备发送私聊 %s: %s", target_id, message[:100]) @@ -244,6 +269,7 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: message, reply_to=reply_to_id, preferred_temp_group_id=_get_context_group_id(context), + history_message=history_message, ) context["message_sent_this_turn"] = True return _format_send_success(sent_message_id) diff --git a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py index 3823bf22..300089c8 100644 --- a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py +++ 
b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py @@ -1,6 +1,11 @@ from typing import Any, Dict import logging +from Undefined.attachments import ( + render_message_with_pic_placeholders, + scope_from_context, +) + logger = logging.getLogger(__name__) @@ -70,6 +75,23 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not message: return "消息内容不能为空" + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + try: + rendered = await render_message_with_pic_placeholders( + message, + registry=attachment_registry, + scope_key=scope_key, + strict=True, + ) + except Exception as exc: + logger.warning( + "[私聊发送] 图片内嵌渲染失败: request_id=%s err=%s", request_id, exc + ) + return f"发送失败:{exc}" + message = rendered.delivery_text + history_message = rendered.history_text + runtime_config = context.get("runtime_config") if runtime_config is not None: if not runtime_config.is_private_allowed(user_id): @@ -81,7 +103,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if sender: try: sent_message_id = await sender.send_private_message( - user_id, message, reply_to=reply_to_id + user_id, + message, + reply_to=reply_to_id, + history_message=history_message, ) context["message_sent_this_turn"] = True return _format_send_success(user_id, sent_message_id) diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 4f79a107..71ac2b43 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -168,6 +168,36 @@ async def _load_history_from_file(self, path: str) -> list[dict[str, Any]]: msg["timestamp"] = "" if "message" not in msg or msg.get("message") is None: msg["message"] = str(msg.get("content", "")) + attachments = msg.get("attachments") + if not isinstance(attachments, list): + msg["attachments"] = [] + else: + normalized_attachments: list[dict[str, str]] = [] + for item in attachments: + if not isinstance(item, dict): + 
continue + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + normalized_attachments.append( + { + "uid": uid, + "kind": str( + item.get("kind") + or item.get("media_type") + or "file" + ), + "media_type": str( + item.get("media_type") + or item.get("kind") + or "file" + ), + "display_name": str( + item.get("display_name", "") or "" + ), + } + ) + msg["attachments"] = normalized_attachments normalized_history.append(msg) @@ -268,6 +298,7 @@ async def add_group_message( role: str = "member", title: str = "", message_id: int | None = None, + attachments: list[dict[str, str]] | None = None, ) -> None: """异步保存群消息到历史记录""" await self._ensure_initialized() @@ -299,6 +330,8 @@ async def add_group_message( } if message_id is not None: record["message_id"] = message_id + if attachments: + record["attachments"] = attachments self._message_history[group_id_str].append(record) @@ -319,6 +352,7 @@ async def add_private_message( display_name: str = "", user_name: str = "", message_id: int | None = None, + attachments: list[dict[str, str]] | None = None, ) -> None: """异步保存私聊消息到历史记录""" await self._ensure_initialized() @@ -345,6 +379,8 @@ async def add_private_message( } if message_id is not None: record["message_id"] = message_id + if attachments: + record["attachments"] = attachments self._private_message_history[user_id_str].append(record) diff --git a/src/Undefined/utils/paths.py b/src/Undefined/utils/paths.py index d3626544..6d824f6b 100644 --- a/src/Undefined/utils/paths.py +++ b/src/Undefined/utils/paths.py @@ -6,10 +6,12 @@ CACHE_DIR = DATA_DIR / "cache" RENDER_CACHE_DIR = CACHE_DIR / "render" IMAGE_CACHE_DIR = CACHE_DIR / "images" +ATTACHMENT_CACHE_DIR = CACHE_DIR / "attachments" DOWNLOAD_CACHE_DIR = CACHE_DIR / "downloads" TEXT_FILE_CACHE_DIR = CACHE_DIR / "text_files" URL_FILE_CACHE_DIR = CACHE_DIR / "url_files" WEBUI_FILE_CACHE_DIR = CACHE_DIR / "webui_files" +ATTACHMENT_REGISTRY_FILE = DATA_DIR / "attachment_registry.json" def ensure_dir(path: Path) 
-> Path: diff --git a/src/Undefined/utils/sender.py b/src/Undefined/utils/sender.py index 3ee06594..471f4948 100644 --- a/src/Undefined/utils/sender.py +++ b/src/Undefined/utils/sender.py @@ -83,6 +83,7 @@ async def send_group_message( *, mark_sent: bool = True, reply_to: int | None = None, + history_message: str | None = None, ) -> int | None: """发送群消息""" if not self.config.is_group_allowed(group_id): @@ -108,8 +109,11 @@ async def send_group_message( # 准备历史记录文本(不含 reply 段) history_content: str | None = None if auto_history: - hist_segments = message_to_segments(message) - history_content = extract_text(hist_segments, self.bot_qq) + if history_message is not None: + history_content = history_message + else: + hist_segments = message_to_segments(message) + history_content = extract_text(hist_segments, self.bot_qq) if history_prefix: history_content = f"{history_prefix}{history_content}" @@ -203,6 +207,7 @@ async def send_private_message( mark_sent: bool = True, reply_to: int | None = None, preferred_temp_group_id: int | None = None, + history_message: str | None = None, ) -> int | None: """发送私聊消息""" if not self.config.is_private_allowed(user_id): @@ -225,8 +230,11 @@ async def send_private_message( # 准备历史记录文本 history_content: str | None = None if auto_history: - hist_segments = message_to_segments(message) - history_content = extract_text(hist_segments, self.bot_qq) + if history_message is not None: + history_content = history_message + else: + hist_segments = message_to_segments(message) + history_content = extract_text(hist_segments, self.bot_qq) # 发送消息 bot_message_id: int | None = None diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index f47c39d8..6968c0fa 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -430,6 +430,12 @@ const payload = raw.slice("base64://".length).trim(); return payload ? 
`data:image/png;base64,${payload}` : ""; } + if (raw.startsWith("file://")) { + const localPath = raw.slice("file://".length).trim(); + return localPath + ? `/api/runtime/chat/image?path=${encodeURIComponent(localPath)}` + : ""; + } if (raw.startsWith("/") || /^[A-Za-z]:[\\/]/.test(raw)) { return `/api/runtime/chat/image?path=${encodeURIComponent(raw)}`; } diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index d8c09262..539e9070 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -8,6 +8,7 @@ import httpx import pytest +from Undefined.attachments import AttachmentRegistry from Undefined.skills.agents.entertainment_agent.tools.ai_draw_one import ( handler as ai_draw_handler, ) @@ -105,6 +106,7 @@ async def _send_image( "prompt": "violet flowers", "size": "1024x1024", "response_format": response_key, + "delivery": "send", "target_id": 10001, "message_type": "group", }, @@ -132,6 +134,7 @@ async def test_execute_models_rejects_invalid_size( { "prompt": "violet flowers", "size": "1:1", + "delivery": "send", "target_id": 10001, "message_type": "group", }, @@ -185,6 +188,7 @@ async def _send_image( { "prompt": "violet flowers", "size": "1024x1024", + "delivery": "send", "target_id": 10001, "message_type": "group", }, @@ -225,6 +229,7 @@ async def _fake_request_with_retry(*_args: Any, **_kwargs: Any) -> Any: { "prompt": "violet flowers", "size": "1024x1024", + "delivery": "send", "target_id": 10001, "message_type": "group", }, @@ -280,6 +285,7 @@ async def _send_image( "prompt": "violet flowers", "model": "dall-e-3", "size": "1024x1024", + "delivery": "send", "target_id": 10001, "message_type": "group", }, @@ -288,3 +294,107 @@ async def _send_image( assert result == "AI 绘图已发送给 group 10001" assert seen_request["json_data"]["model"] == "grok-imagine-1.0" + + +@pytest.mark.asyncio +async def test_execute_defaults_to_embed_and_returns_pic_uid( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> 
None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(request_params={}), + ) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry(*_args: Any, **_kwargs: Any) -> _FakeResponse: + return _FakeResponse() + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1024x1024", + }, + { + "attachment_registry": registry, + "request_type": "group", + "group_id": 10001, + }, + ) + + assert result.startswith('已生成图片,可在回复中插入 None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(request_params={}), + ) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry(*_args: Any, **_kwargs: Any) -> _FakeResponse: + return _FakeResponse() + + sent: dict[str, Any] = {} + + async def _send_image( + target_id: int | str, + message_type: str, + file_path: str, + ) -> None: + sent["target_id"] = target_id + sent["message_type"] = message_type + sent["file_path"] = file_path + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + result = await ai_draw_handler.execute( + { + "prompt": "violet flowers", + "size": "1024x1024", + "delivery": "send", + }, + { + "attachment_registry": registry, + "request_type": "group", + "group_id": 10001, + 
"send_image_callback": _send_image, + }, + ) + + assert result == "AI 绘图已发送给 group 10001" + assert sent["target_id"] == 10001 + assert sent["message_type"] == "group" + assert Path(sent["file_path"]).read_bytes() == _PNG_BYTES diff --git a/tests/test_attachments.py b/tests/test_attachments.py new file mode 100644 index 00000000..1e927aa1 --- /dev/null +++ b/tests/test_attachments.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import base64 +from pathlib import Path + +import pytest + +from Undefined.attachments import ( + AttachmentRegistry, + register_message_attachments, + render_message_with_pic_placeholders, +) + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +@pytest.mark.asyncio +async def test_attachment_registry_persists_and_respects_scope( + tmp_path: Path, +) -> None: + registry_path = tmp_path / "attachment_registry.json" + cache_dir = tmp_path / "attachments" + registry = AttachmentRegistry(registry_path=registry_path, cache_dir=cache_dir) + + record = await registry.register_bytes( + "group:10001", + _PNG_BYTES, + kind="image", + display_name="cat.png", + source_kind="test", + ) + + reloaded = AttachmentRegistry(registry_path=registry_path, cache_dir=cache_dir) + assert reloaded.resolve(record.uid, "group:10001") is not None + assert reloaded.resolve(record.uid, "group:10002") is None + + +@pytest.mark.asyncio +async def test_register_message_attachments_normalizes_webui_base64_image( + tmp_path: Path, +) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + payload = base64.b64encode(_PNG_BYTES).decode("ascii") + segments = [ + {"type": "text", "data": {"text": "我给你看"}}, + {"type": "image", "data": {"file": f"base64://{payload}"}}, + {"type": 
"text", "data": {"text": "这张图"}}, + ] + + result = await register_message_attachments( + registry=registry, + segments=segments, + scope_key="webui", + ) + + assert len(result.attachments) == 1 + uid = result.attachments[0]["uid"] + assert uid.startswith("pic_") + assert uid in result.normalized_text + assert "这张图" in result.normalized_text + + +@pytest.mark.asyncio +async def test_render_message_with_pic_placeholders_uses_file_uri_and_shadow_text( + tmp_path: Path, +) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "group:10001", + _PNG_BYTES, + kind="image", + display_name="cat.png", + source_kind="test", + ) + + rendered = await render_message_with_pic_placeholders( + f'介绍一下\n\n如图', + registry=registry, + scope_key="group:10001", + strict=True, + ) + + assert "[CQ:image,file=file://" in rendered.delivery_text + assert f"[图片 uid={record.uid} name=cat.png]" in rendered.history_text diff --git a/tests/test_file_analysis_attachment_uid.py b/tests/test_file_analysis_attachment_uid.py new file mode 100644 index 00000000..24d41c0b --- /dev/null +++ b/tests/test_file_analysis_attachment_uid.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from Undefined.attachments import AttachmentRegistry +from Undefined.skills.agents.file_analysis_agent.tools.download_file import ( + handler as download_file_handler, +) + + +@pytest.mark.asyncio +async def test_download_file_supports_internal_attachment_uid( + tmp_path: Path, +) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "private:12345", + b"hello attachment", + kind="file", + display_name="demo.txt", + source_kind="test", + ) + + result = await download_file_handler.execute( + {"file_source": record.uid}, + { 
+ "attachment_registry": registry, + "request_type": "private", + "user_id": 12345, + }, + ) + + downloaded = Path(result) + assert downloaded.is_file() + assert downloaded.name == "demo.txt" + assert downloaded.read_bytes() == b"hello attachment" diff --git a/tests/test_send_message_tool.py b/tests/test_send_message_tool.py index ae6d19ee..12305b62 100644 --- a/tests/test_send_message_tool.py +++ b/tests/test_send_message_tool.py @@ -1,11 +1,13 @@ from __future__ import annotations +from pathlib import Path from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock import pytest +from Undefined.attachments import AttachmentRegistry from Undefined.skills.toolsets.messages.send_message.handler import execute @@ -49,6 +51,7 @@ async def test_send_message_private_passes_context_group_as_preferred_temp_group "hello", reply_to=None, preferred_temp_group_id=10001, + history_message="hello", ) sender.send_group_message.assert_not_called() assert context["message_sent_this_turn"] is True @@ -131,7 +134,10 @@ async def test_send_message_does_not_implicitly_use_trigger_message_id() -> None assert result == "消息已发送" sender.send_group_message.assert_called_once_with( - 10001, "hello without quote", reply_to=None + 10001, + "hello without quote", + reply_to=None, + history_message="hello without quote", ) @@ -158,3 +164,45 @@ async def test_send_message_returns_sent_message_id_when_available() -> None: ) assert result == "消息已发送(message_id=77777)" + + +@pytest.mark.asyncio +async def test_send_message_renders_pic_uid_before_sending(tmp_path: Path) -> None: + sender = SimpleNamespace( + send_group_message=AsyncMock(return_value=77777), + send_private_message=AsyncMock(), + ) + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "group:10001", + b"\x89PNG\r\n\x1a\n", + kind="image", + display_name="demo.png", + source_kind="test", 
+ ) + context: dict[str, Any] = { + "request_type": "group", + "group_id": 10001, + "sender_id": 20002, + "request_id": "req-6", + "runtime_config": _build_runtime_config(), + "sender": sender, + "attachment_registry": registry, + } + + result = await execute( + { + "message": f'图文并茂\n<pic uid="{record.uid}"/>\n结束', + }, + context, + ) + + assert result == "消息已发送(message_id=77777)" + sent_args = sender.send_group_message.await_args + assert "[CQ:image,file=file://" in sent_args.args[1] + assert sent_args.kwargs["history_message"] == ( + f"图文并茂\n[图片 uid={record.uid} name=demo.png]\n结束" + ) diff --git a/tests/test_send_private_message_tool.py b/tests/test_send_private_message_tool.py index 69bf71d7..0cd79d00 100644 --- a/tests/test_send_private_message_tool.py +++ b/tests/test_send_private_message_tool.py @@ -60,3 +60,9 @@ async def test_send_private_message_returns_sent_message_id_when_available() -> ) assert result == "私聊消息已发送给用户 12345(message_id=99999)" + sender.send_private_message.assert_awaited_once_with( + 12345, + "hello sender private", + reply_to=None, + history_message="hello sender private", + ) From dad29705e811e1fe292e31585d98e2a112981726 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 17:28:00 +0800 Subject: [PATCH 07/21] feat(image-gen): add agent-based prompt moderation --- res/prompts/image_gen_moderation.txt | 16 +++ .../tools/ai_draw_one/handler.py | 97 +++++++++++++++++++ tests/test_ai_draw_one_handler.py | 43 ++++++++ 3 files changed, 156 insertions(+) create mode 100644 res/prompts/image_gen_moderation.txt diff --git a/res/prompts/image_gen_moderation.txt b/res/prompts/image_gen_moderation.txt new file mode 100644 index 00000000..d8cc6f91 --- /dev/null +++ b/res/prompts/image_gen_moderation.txt @@ -0,0 +1,16 @@ +你是图片生成审核助手,只根据“待生成图片的提示词”判断是否允许生成。 + +拦截范围包括但不限于: +- 露骨色情、性剥削、未成年人性化 +- 血腥暴力、肢解、虐杀等重口内容 +- 明显违法、极端、恐怖主义相关内容 +- 明显侵犯隐私、仇恨或其他高风险内容 + +审核原则: +- 只审核风险,不提供改写建议,不解释政策,不输出多余内容 +- 安全、普通、模糊但不明显违规的内容默认放行 +- 如果只是普通人物、风景、二次元、服饰、轻微动作描写,不要误杀 + 
+输出格式必须严格遵守: +- 允许:`ALLOW` +- 拒绝:`BLOCK: <不超过20字的简短中文原因>` diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 56d7683c..0c51bcef 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -18,8 +18,10 @@ import httpx from Undefined.attachments import scope_from_context +from Undefined.ai.parsing import extract_choices_content from Undefined.skills.http_client import request_with_retry from Undefined.skills.http_config import get_request_timeout, get_xingzhige_url +from Undefined.utils.resources import read_text_resource logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ "1024x1024", ) _ALLOWED_IMAGE_RESPONSE_FORMATS = ("url", "b64_json", "base64") +_IMAGE_GEN_MODERATION_PROMPT: str | None = None @dataclass @@ -148,6 +151,93 @@ def _format_upstream_error_message(response: httpx.Response) -> str: return message or default_message +def _get_image_gen_moderation_prompt() -> str: + global _IMAGE_GEN_MODERATION_PROMPT + if _IMAGE_GEN_MODERATION_PROMPT is not None: + return _IMAGE_GEN_MODERATION_PROMPT + try: + _IMAGE_GEN_MODERATION_PROMPT = read_text_resource( + "res/prompts/image_gen_moderation.txt" + ) + except Exception as exc: + logger.error("加载生图审核提示词失败: %s", exc) + _IMAGE_GEN_MODERATION_PROMPT = ( + "你是图片生成审核助手。" + "你只根据待生成图片的提示词判断是否允许生成。" + "如果安全则只输出 ALLOW。" + "如果应拒绝则输出 BLOCK: <简短中文原因>。" + ) + return _IMAGE_GEN_MODERATION_PROMPT + + +def _resolve_agent_model_for_moderation(context: dict[str, Any]) -> Any | None: + ai_client = context.get("ai_client") + if ai_client is None: + return None + model_config = getattr(ai_client, "agent_config", None) + if model_config is None: + return None + + runtime_config = context.get("runtime_config") + model_selector = getattr(ai_client, "model_selector", None) + if runtime_config is not 
None and model_selector is not None: + try: + group_id = context.get("group_id", 0) or 0 + user_id = context.get("user_id", 0) or 0 + global_enabled = bool(getattr(runtime_config, "model_pool_enabled", False)) + selected = model_selector.select_agent_config( + model_config, + group_id=group_id, + user_id=user_id, + global_enabled=global_enabled, + ) + if selected is not None: + return selected + except Exception as exc: + logger.debug("生图审核选择 agent 模型失败,回退默认 agent_config: %s", exc) + return model_config + + +async def _moderate_prompt_with_agent_model( + prompt: str, + context: dict[str, Any], +) -> str | None: + text = str(prompt or "").strip() + if not text: + return None + + ai_client = context.get("ai_client") + model_config = _resolve_agent_model_for_moderation(context) + if ai_client is None or model_config is None: + logger.debug("生图审核跳过:缺少 ai_client 或 agent 模型配置") + return None + + try: + result = await ai_client.request_model( + model_config=model_config, + messages=[ + {"role": "system", "content": _get_image_gen_moderation_prompt()}, + {"role": "user", "content": f"待审核的生图提示词:\n{text}"}, + ], + max_tokens=64, + call_type="image_gen_moderation", + ) + content = extract_choices_content(result).strip() + except Exception as exc: + logger.warning("生图审核调用失败,按允许继续: %s", exc) + return None + + upper = content.upper() + if upper.startswith("ALLOW"): + return None + if upper.startswith("BLOCK"): + reason = content.split(":", 1)[1].strip() if ":" in content else "命中敏感内容" + return f"图片生成请求被审核拦截:{reason or '命中敏感内容'}" + + logger.warning("生图审核返回了无法识别的结果,按允许继续: %s", content) + return None + + def _build_openai_models_request_body( *, prompt: str, @@ -486,6 +576,13 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: try: if delivery not in {"embed", "send"}: return f"delivery 无效:{delivery}。仅支持 embed 或 send" + moderation_error = await _moderate_prompt_with_agent_model( + prompt_arg or "", + context, + ) + if moderation_error: + logger.warning("AI 
绘图请求被 agent 审核拦截: prompt=%s", prompt_arg or "") + return moderation_error if provider == "xingzhige": prompt = prompt_arg or "" diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 539e9070..7c368bc9 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -4,6 +4,7 @@ from pathlib import Path from types import SimpleNamespace from typing import Any +from unittest.mock import AsyncMock import httpx import pytest @@ -241,6 +242,48 @@ async def _fake_request_with_retry(*_args: Any, **_kwargs: Any) -> Any: assert "Image generation blocked or no valid final image" in result +@pytest.mark.asyncio +async def test_execute_blocks_when_agent_moderation_rejects( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(), + ) + + async def _fail_if_called(*_args: Any, **_kwargs: Any) -> Any: + raise AssertionError("image generation request should not be sent") + + fake_ai_client = SimpleNamespace( + agent_config=SimpleNamespace(model_name="agent-model"), + request_model=AsyncMock( + return_value={ + "choices": [ + {"message": {"content": "BLOCK: 露骨色情内容"}}, + ] + } + ), + ) + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fail_if_called) + + result = await ai_draw_handler.execute( + { + "prompt": "explicit adult scene", + "delivery": "send", + "target_id": 10001, + "message_type": "group", + }, + { + "ai_client": fake_ai_client, + "send_image_callback": lambda *_args, **_kwargs: None, + }, + ) + + assert result == "图片生成请求被审核拦截:露骨色情内容" + fake_ai_client.request_model.assert_awaited_once() + + @pytest.mark.asyncio async def test_execute_models_uses_configured_model_only( monkeypatch: pytest.MonkeyPatch, From 7cf6bef86edf26488346dc3f72ffda7ce6189599 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 17:37:24 +0800 Subject: [PATCH 08/21] feat(multimodal): support forwarded attachment 
uids --- src/Undefined/api/app.py | 1 + src/Undefined/attachments.py | 314 ++++++++++++++++++++++------------- src/Undefined/handlers.py | 3 + tests/test_attachments.py | 32 ++++ 4 files changed, 236 insertions(+), 114 deletions(-) diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 1381b542..fc913107 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1215,6 +1215,7 @@ async def _run_webui_chat( segments=input_segments, scope_key=webui_scope_key, resolve_image_url=self._ctx.onebot.get_image, + get_forward_messages=self._ctx.onebot.get_forward_msg, ) normalized_text = registered_input.normalized_text or text await self._ctx.history_manager.add_private_message( diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index a8f187ac..00d5e7bd 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -58,6 +58,7 @@ (b"GIF89a", ".gif"), (b"BM", ".bmp"), ) +_FORWARD_ATTACHMENT_MAX_DEPTH = 3 @dataclass(frozen=True) @@ -338,6 +339,37 @@ def _resolve_webui_file_id(file_id: str) -> Path | None: return None +def _extract_forward_id(data: Mapping[str, Any]) -> str: + forward_id = data.get("id") or data.get("resid") or data.get("message_id") + return str(forward_id).strip() if forward_id is not None else "" + + +def _normalize_message_segments(message: Any) -> list[Mapping[str, Any]]: + if isinstance(message, list): + normalized: list[Mapping[str, Any]] = [] + for item in message: + if isinstance(item, Mapping): + normalized.append(item) + elif isinstance(item, str): + normalized.append({"type": "text", "data": {"text": item}}) + return normalized + if isinstance(message, Mapping): + return [message] + if isinstance(message, str): + return [{"type": "text", "data": {"text": message}}] + return [] + + +def _normalize_forward_nodes(raw_nodes: Any) -> list[Mapping[str, Any]]: + if isinstance(raw_nodes, list): + return [node for node in raw_nodes if isinstance(node, Mapping)] + if isinstance(raw_nodes, 
Mapping): + messages = raw_nodes.get("messages") + if isinstance(messages, list): + return [node for node in messages if isinstance(node, Mapping)] + return [] + + class AttachmentRegistry: """Persistent attachment registry scoped by conversation.""" @@ -560,6 +592,8 @@ async def register_message_attachments( segments: Sequence[Mapping[str, Any]], scope_key: str | None, resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, + get_forward_messages: Callable[[str], Awaitable[list[dict[str, Any]]]] + | None = None, ) -> RegisteredMessageAttachments: attachments: list[dict[str, str]] = [] normalized_parts: list[str] = [] @@ -574,137 +608,189 @@ async def register_message_attachments( normalized_text="".join(normalized_parts).strip(), ) - for index, segment in enumerate(segments): - type_ = str(segment.get("type", "") or "").strip().lower() - raw_data = segment.get("data", {}) - data = raw_data if isinstance(raw_data, Mapping) else {} - ref: dict[str, str] | None = None + visited_forward_ids: set[str] = set() - try: - if type_ == "image": - raw_source = str(data.get("file") or data.get("url") or "").strip() - display_name = _display_name_from_source( - raw_source, - f"image_{index + 1}.png", - ) - if raw_source.startswith("base64://"): - payload = raw_source[len("base64://") :].strip() - content = base64.b64decode(payload) - record = await registry.register_bytes( - scope_key, - content, - kind="image", - display_name=display_name, - source_kind="base64_image", - source_ref=f"segment:{index}", - ) - ref = record.prompt_ref() - elif _is_data_url(raw_source): - record = await registry.register_data_url( - scope_key, + async def _collect_from_segments( + current_segments: Sequence[Mapping[str, Any]], + *, + depth: int, + prefix: str, + ) -> None: + for index, segment in enumerate(current_segments): + type_ = str(segment.get("type", "") or "").strip().lower() + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + 
ref: dict[str, str] | None = None + + try: + if type_ == "image": + raw_source = str(data.get("file") or data.get("url") or "").strip() + display_name = _display_name_from_source( raw_source, - kind="image", - display_name=display_name, - source_kind="data_url_image", - source_ref=f"segment:{index}", + f"image_{index + 1}.png", ) - ref = record.prompt_ref() - else: - resolved_source = raw_source - if raw_source and resolve_image_url is not None: - try: - resolved = await resolve_image_url(raw_source) - except Exception as exc: - logger.debug( - "[AttachmentRegistry] image resolver failed: file=%s err=%s", - raw_source, - exc, - ) - resolved = None - if resolved: - resolved_source = str(resolved) - - if _is_http_url(resolved_source): - record = await registry.register_remote_url( + if raw_source.startswith("base64://"): + payload = raw_source[len("base64://") :].strip() + content = base64.b64decode(payload) + record = await registry.register_bytes( scope_key, - resolved_source, + content, kind="image", display_name=display_name, - source_kind="remote_image", - source_ref=raw_source or resolved_source, + source_kind="base64_image", + source_ref=f"{prefix}segment:{index}", ) ref = record.prompt_ref() - elif _is_localish_path(resolved_source): - local_path = ( - resolved_source[7:] - if resolved_source.startswith("file://") - else resolved_source + elif _is_data_url(raw_source): + record = await registry.register_data_url( + scope_key, + raw_source, + kind="image", + display_name=display_name, + source_kind="data_url_image", + source_ref=f"{prefix}segment:{index}", ) + ref = record.prompt_ref() + else: + resolved_source = raw_source + if raw_source and resolve_image_url is not None: + try: + resolved = await resolve_image_url(raw_source) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] image resolver failed: file=%s err=%s", + raw_source, + exc, + ) + resolved = None + if resolved: + resolved_source = str(resolved) + + if 
_is_http_url(resolved_source): + record = await registry.register_remote_url( + scope_key, + resolved_source, + kind="image", + display_name=display_name, + source_kind="remote_image", + source_ref=raw_source or resolved_source, + ) + ref = record.prompt_ref() + elif _is_localish_path(resolved_source): + local_path = ( + resolved_source[7:] + if resolved_source.startswith("file://") + else resolved_source + ) + record = await registry.register_local_file( + scope_key, + local_path, + kind="image", + display_name=display_name, + source_kind="local_image", + source_ref=raw_source or resolved_source, + ) + ref = record.prompt_ref() + + elif type_ == "file": + file_id = str(data.get("id", "") or "").strip() + raw_source = str(data.get("file") or data.get("url") or "").strip() + local_file_path: Path | None = None + if file_id: + local_file_path = _resolve_webui_file_id(file_id) + elif _is_localish_path(raw_source): + local_file_path = Path( + raw_source[7:] + if raw_source.startswith("file://") + else raw_source + ) + display_name = ( + str(data.get("name", "") or "").strip() + or (local_file_path.name if local_file_path is not None else "") + or _display_name_from_source( + raw_source, f"file_{index + 1}.bin" + ) + ) + if local_file_path is not None and local_file_path.is_file(): record = await registry.register_local_file( scope_key, - local_path, - kind="image", + local_file_path, + kind="file", + display_name=display_name, + source_kind="webui_file" if file_id else "local_file", + source_ref=file_id or raw_source or str(local_file_path), + ) + ref = record.prompt_ref() + elif _is_http_url(raw_source): + record = await registry.register_remote_url( + scope_key, + raw_source, + kind="file", display_name=display_name, - source_kind="local_image", - source_ref=raw_source or resolved_source, + source_kind="remote_file", + source_ref=file_id or raw_source, ) ref = record.prompt_ref() - elif type_ == "file": - file_id = str(data.get("id", "") or "").strip() - raw_source = 
str(data.get("file") or data.get("url") or "").strip() - local_file_path: Path | None = None - if file_id: - local_file_path = _resolve_webui_file_id(file_id) - elif _is_localish_path(raw_source): - local_file_path = Path( - raw_source[7:] - if raw_source.startswith("file://") - else raw_source - ) - display_name = ( - str(data.get("name", "") or "").strip() - or (local_file_path.name if local_file_path is not None else "") - or _display_name_from_source(raw_source, f"file_{index + 1}.bin") + elif ( + type_ == "forward" + and get_forward_messages is not None + and depth < _FORWARD_ATTACHMENT_MAX_DEPTH + ): + forward_id = _extract_forward_id(data) + if forward_id and forward_id not in visited_forward_ids: + visited_forward_ids.add(forward_id) + try: + nodes = _normalize_forward_nodes( + await get_forward_messages(forward_id) + ) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] forward resolver failed: id=%s err=%s", + forward_id, + exc, + ) + nodes = [] + for node_index, node in enumerate(nodes): + raw_message = ( + node.get("content") + or node.get("message") + or node.get("raw_message") + ) + nested_segments = _normalize_message_segments(raw_message) + if not nested_segments: + continue + await _collect_from_segments( + nested_segments, + depth=depth + 1, + prefix=f"{prefix}forward:{forward_id}:{node_index}:", + ) + except ( + binascii.Error, + ValueError, + FileNotFoundError, + httpx.HTTPError, + ) as exc: + logger.warning( + "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", + type_, + index, + exc, ) - if local_file_path is not None and local_file_path.is_file(): - record = await registry.register_local_file( - scope_key, - local_file_path, - kind="file", - display_name=display_name, - source_kind="webui_file" if file_id else "local_file", - source_ref=file_id or raw_source or str(local_file_path), - ) - ref = record.prompt_ref() - elif _is_http_url(raw_source): - record = await registry.register_remote_url( - 
scope_key, - raw_source, - kind="file", - display_name=display_name, - source_kind="remote_file", - source_ref=file_id or raw_source, - ) - ref = record.prompt_ref() - except (binascii.Error, ValueError, FileNotFoundError, httpx.HTTPError) as exc: - logger.warning( - "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", - type_, - index, - exc, - ) - except Exception as exc: - logger.exception( - "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", - type_, - index, - exc, - ) + except Exception as exc: + logger.exception( + "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", + type_, + index, + exc, + ) + + if ref is not None: + attachments.append(ref) + if depth == 0: + normalized_parts.append(_segment_text(type_, data, ref)) - if ref is not None: - attachments.append(ref) - normalized_parts.append(_segment_text(type_, data, ref)) + await _collect_from_segments(segments, depth=0, prefix="") return RegisteredMessageAttachments( attachments=attachments, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index fb3f61c1..decff51a 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -150,6 +150,9 @@ async def _collect_message_attachments( segments=message_content, scope_key=scope_key, resolve_image_url=resolve_image_url, + get_forward_messages=getattr(onebot, "get_forward_msg", None) + if onebot + else None, ) return result.attachments diff --git a/tests/test_attachments.py b/tests/test_attachments.py index 1e927aa1..1c7874b1 100644 --- a/tests/test_attachments.py +++ b/tests/test_attachments.py @@ -72,6 +72,38 @@ async def test_register_message_attachments_normalizes_webui_base64_image( assert "这张图" in result.normalized_text +@pytest.mark.asyncio +async def test_register_message_attachments_recurses_into_forward_images( + tmp_path: Path, +) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", 
+ cache_dir=tmp_path / "attachments", + ) + payload = base64.b64encode(_PNG_BYTES).decode("ascii") + + async def _fake_get_forward(_forward_id: str) -> list[dict[str, object]]: + return [ + { + "message": [ + {"type": "text", "data": {"text": "转发内容"}}, + {"type": "image", "data": {"file": f"base64://{payload}"}}, + ] + } + ] + + result = await register_message_attachments( + registry=registry, + segments=[{"type": "forward", "data": {"id": "forward-1"}}], + scope_key="group:10001", + get_forward_messages=_fake_get_forward, + ) + + assert result.normalized_text == "[合并转发: forward-1]" + assert len(result.attachments) == 1 + assert result.attachments[0]["uid"].startswith("pic_") + + @pytest.mark.asyncio async def test_render_message_with_pic_placeholders_uses_file_uri_and_shadow_text( tmp_path: Path, From 9102ee228175e209a68832ae6cb481269f0f8468 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 18:01:49 +0800 Subject: [PATCH 09/21] feat(image-gen): support reference image edits --- config.toml.example | 17 + src/Undefined/config/loader.py | 34 ++ .../agents/entertainment_agent/prompt.md | 1 + .../tools/ai_draw_one/config.json | 7 + .../tools/ai_draw_one/handler.py | 301 ++++++++++++++++-- src/Undefined/skills/http_client.py | 2 + tests/test_ai_draw_one_handler.py | 111 +++++++ tests/test_config_request_params.py | 24 ++ 8 files changed, 478 insertions(+), 19 deletions(-) diff --git a/config.toml.example b/config.toml.example index 0e5b47e0..0c90af2d 100644 --- a/config.toml.example +++ b/config.toml.example @@ -527,6 +527,23 @@ model_name = "" # en: Extra request-body params (optional). [models.image_gen.request_params] +# zh: 参考图生图模型配置(用于 ai_draw_one 传入 reference_image_uids 时调用 OpenAI 兼容的图片编辑接口)。 +# en: Reference-image generation model config (used when ai_draw_one receives reference_image_uids and calls the OpenAI-compatible image editing API). 
+[models.image_edit] +# zh: OpenAI-compatible 基址 URL,例如 https://api.openai.com/v1(最终请求路径为 /v1/images/edits)。 +# en: OpenAI-compatible base URL, e.g. https://api.openai.com/v1 (final request path is /v1/images/edits). +api_url = "" +# zh: API Key。 +# en: API key. +api_key = "" +# zh: 模型名称,空则回退到 [models.image_gen] 的 model_name。 +# en: Model name, empty falls back to [models.image_gen].model_name. +model_name = "" + +# zh: 额外请求体参数(可选)。 +# en: Extra request-body params (optional). +[models.image_edit.request_params] + # zh: 本地知识库配置。 # en: Local knowledge base settings. [knowledge] diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index 7a11e27b..40fda0e7 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -584,6 +584,7 @@ class Config: # 生图工具配置 image_gen: ImageGenConfig models_image_gen: ImageGenModelConfig + models_image_edit: ImageGenModelConfig _allowed_group_ids_set: set[int] = dataclass_field( default_factory=set, init=False, @@ -1332,6 +1333,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi cognitive = cls._parse_cognitive_config(data) naga = cls._parse_naga_config(data) models_image_gen = cls._parse_image_gen_model_config(data) + models_image_edit = cls._parse_image_edit_model_config(data) image_gen = cls._parse_image_gen_config(data) if strict: @@ -1479,6 +1481,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi naga=naga, image_gen=image_gen, models_image_gen=models_image_gen, + models_image_edit=models_image_edit, ) @property @@ -2498,6 +2501,37 @@ def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: request_params=_get_model_request_params(data, "image_gen"), ) + @staticmethod + def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_edit] 参考图生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, + ("models", "image_edit", 
"api_url"), + "IMAGE_EDIT_MODEL_API_URL", + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_key"), + "IMAGE_EDIT_MODEL_API_KEY", + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, + ("models", "image_edit", "model_name"), + "IMAGE_EDIT_MODEL_NAME", + ), + "", + ), + request_params=_get_model_request_params(data, "image_edit"), + ) + @staticmethod def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: """解析 [image_gen] 生图工具配置""" diff --git a/src/Undefined/skills/agents/entertainment_agent/prompt.md b/src/Undefined/skills/agents/entertainment_agent/prompt.md index cff647b3..d74c249c 100644 --- a/src/Undefined/skills/agents/entertainment_agent/prompt.md +++ b/src/Undefined/skills/agents/entertainment_agent/prompt.md @@ -7,6 +7,7 @@ - 输出轻松友好,但不要过度承诺或编造。 - 如果工具返回了图片 UID,且用户需要图文并茂的结果,优先在最终回复里用 `<pic uid="..."/>` 做图文混排,而不是单独口头描述“我发图了”。 - `<pic uid="..."/>` 只能引用当前会话里真实存在的图片 UID,不能臆造,也不要改写成 Markdown 图片语法。 +- 如果用户明确要求“参考这张图画”“照这个风格重画”“基于这些图生成”,优先调用 `ai_draw_one` 并传入 `reference_image_uids`,不要把图片内容重新手写成长段描述后再当纯文本生图。 边界提醒: - 正经资讯或需要核验的信息,引导至 info_agent / web_agent。 diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json index e6d8b525..63d8531c 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/config.json @@ -32,6 +32,13 @@ "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到当前或显式指定目标", "enum": ["embed", "send"] }, + "reference_image_uids": { + "type": "array", + "description": "参考图 UID 列表。传入后会改走参考图生图(images/edits)模式,UID 必须是当前会话内可访问的图片 UID。", + "items": { + "type": "string" + } + }, "n": { "type": "integer", "description": "生成图片数量(仅 models 模式生效,1 到 10)" diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py 
b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 0c51bcef..ec93c4d6 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -9,7 +9,10 @@ import base64 import binascii +import json import logging +import mimetypes +from pathlib import Path import time import uuid from dataclasses import dataclass @@ -33,6 +36,7 @@ "1024x1024", ) _ALLOWED_IMAGE_RESPONSE_FORMATS = ("url", "b64_json", "base64") +_MAX_REFERENCE_IMAGE_UIDS = 16 _IMAGE_GEN_MODERATION_PROMPT: str | None = None @@ -272,6 +276,69 @@ def _build_openai_models_request_body( return body +def _coerce_reference_image_uids(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + text = value.strip() + return [text] if text else [] + if not isinstance(value, list): + return [] + resolved: list[str] = [] + seen: set[str] = set() + for item in value: + text = str(item or "").strip() + if not text or text in seen: + continue + seen.add(text) + resolved.append(text) + return resolved + + +def _stringify_multipart_value(value: Any) -> str: + if value is None: + return "" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (str, int, float)): + return str(value) + return json.dumps(value, ensure_ascii=False) + + +def _build_openai_models_edit_form( + *, + prompt: str, + model_name: str, + size: str, + quality: str, + style: str, + response_format: str, + n: int | None, + extra_params: dict[str, Any], +) -> dict[str, str]: + from Undefined.utils.request_params import merge_request_params + + body = merge_request_params(extra_params) + body["prompt"] = prompt + if n is not None: + body["n"] = n + else: + body.setdefault("n", 1) + if model_name: + body["model"] = model_name + if size: + body["size"] = size + if quality: + body["quality"] = quality + if style: + body["style"] = style + if 
response_format: + body["response_format"] = response_format + else: + body.setdefault("response_format", "base64") + return {key: _stringify_multipart_value(value) for key, value in body.items()} + + def _validate_openai_models_request_body(body: dict[str, Any]) -> str | None: size = str(body.get("size", "") or "").strip() if size and size not in _ALLOWED_MODELS_IMAGE_SIZES: @@ -453,6 +520,145 @@ async def _call_openai_models( return generated_image +def _guess_upload_media_type(path: Path) -> str: + guessed, _ = mimetypes.guess_type(path.name) + return str(guessed or "application/octet-stream") + + +async def _resolve_reference_image_paths( + reference_image_uids: list[str], + context: dict[str, Any], +) -> tuple[list[Path] | None, str | None]: + if not reference_image_uids: + return [], None + if len(reference_image_uids) > _MAX_REFERENCE_IMAGE_UIDS: + return ( + None, + f"reference_image_uids 最多支持 {_MAX_REFERENCE_IMAGE_UIDS} 张参考图", + ) + + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry is None or not scope_key: + return None, "当前会话未提供附件注册能力,无法解析参考图 UID" + + resolved_paths: list[Path] = [] + for uid in reference_image_uids: + record = attachment_registry.resolve(uid, scope_key) + if record is None: + return None, f"参考图 UID 不存在或不属于当前会话:{uid}" + if str(getattr(record, "media_type", "") or "").strip().lower() != "image": + return None, f"参考图 UID 不是图片:{uid}" + local_path = str(getattr(record, "local_path", "") or "").strip() + if not local_path: + return None, f"参考图 UID 缺少本地缓存文件:{uid}" + path = Path(local_path) + if not path.is_file(): + return None, f"参考图 UID 的本地缓存文件不存在:{uid}" + resolved_paths.append(path) + return resolved_paths, None + + +async def _call_openai_models_edit( + *, + prompt: str, + api_url: str, + api_key: str, + model_name: str, + size: str, + quality: str, + style: str, + response_format: str, + n: int | None, + timeout_val: float, + reference_image_paths: list[Path], + 
extra_params: dict[str, Any], + context: dict[str, Any], +) -> _GeneratedImagePayload | str: + form_data = _build_openai_models_edit_form( + prompt=prompt, + model_name=model_name, + size=size, + quality=quality, + style=style, + response_format=response_format, + n=n, + extra_params=extra_params, + ) + validation_error = _validate_openai_models_request_body( + {key: value for key, value in form_data.items()} + ) + if validation_error: + return validation_error + + base_url = api_url.rstrip("/") + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + url = f"{base_url}/images/edits" + + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + file_handles: list[Any] = [] + files: list[tuple[str, tuple[str, Any, str]]] = [] + try: + for path in reference_image_paths: + file_handle = path.open("rb") + file_handles.append(file_handle) + files.append( + ( + "image", + ( + path.name, + file_handle, + _guess_upload_media_type(path), + ), + ) + ) + + try: + response = await request_with_retry( + "POST", + url, + data=form_data, + files=files, + headers=headers or None, + timeout=timeout_val, + context=context, + ) + except httpx.HTTPStatusError as exc: + message = _format_upstream_error_message(exc.response) + return f"参考图生图请求失败: HTTP {exc.response.status_code} {message}" + except httpx.TimeoutException: + return f"参考图生图请求超时({timeout_val:.0f}s)" + except httpx.RequestError as exc: + return f"参考图生图请求失败: {exc}" + + try: + data = response.json() + except Exception: + return f"API 返回错误 (非JSON): {response.text[:100]}" + + generated_image = _parse_generated_image(data) + if generated_image is None: + logger.error(f"参考图生图 API 返回 (未找到图片内容): {data}") + return f"API 返回原文 (错误:未找到图片内容): {data}" + + logger.info(f"参考图生图 API 返回: {data}") + if generated_image.image_url: + logger.info(f"提取图片链接: {generated_image.image_url}") + elif generated_image.image_bytes is not None: + logger.info("提取图片字节: bytes=%s", len(generated_image.image_bytes)) + 
return generated_image + finally: + for handle in file_handles: + try: + handle.close() + except Exception: + pass + + async def _download_and_send( image_url: str, target_id: int | str, @@ -559,12 +765,16 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: style_arg: str | None = args.get("style") response_format_arg: str | None = args.get("response_format") n_arg = args.get("n") + reference_image_uids = _coerce_reference_image_uids( + args.get("reference_image_uids") + ) delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id: int | str | None = args.get("target_id") message_type_arg: str | None = args.get("message_type") cfg = get_config(strict=False).image_gen gen_cfg = get_config(strict=False).models_image_gen + edit_cfg = get_config(strict=False).models_image_edit chat_cfg = get_config(strict=False).chat_model provider = cfg.provider @@ -585,15 +795,32 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: return moderation_error if provider == "xingzhige": + if reference_image_uids: + return "图片生成失败:xingzhige provider 不支持参考图生图" prompt = prompt_arg or "" size = size_arg or cfg.xingzhige_size generated_result = await _call_xingzhige(prompt, size, context) elif provider == "models": prompt = prompt_arg or "" - # 降级到 models.image_gen 配置,未填则降级到 chat_model - api_url = gen_cfg.api_url or chat_cfg.api_url - api_key = gen_cfg.api_key or chat_cfg.api_key - model_name = str(gen_cfg.model_name or "").strip() + use_reference_images = bool(reference_image_uids) + selected_cfg = edit_cfg if use_reference_images else gen_cfg + fallback_cfg = gen_cfg if use_reference_images else None + # 降级到独立的 image 配置,未填再降级到 chat_model + api_url = ( + selected_cfg.api_url + or (fallback_cfg.api_url if fallback_cfg is not None else "") + or chat_cfg.api_url + ) + api_key = ( + selected_cfg.api_key + or (fallback_cfg.api_key if fallback_cfg is not None else "") + or chat_cfg.api_key + ) + model_name = str( + 
selected_cfg.model_name + or (fallback_cfg.model_name if fallback_cfg is not None else "") + or "" + ).strip() size = str(size_arg or cfg.openai_size or "").strip() quality = str(quality_arg or cfg.openai_quality or "").strip() style = str(style_arg or cfg.openai_style or "").strip() @@ -607,24 +834,60 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: return f"n 无效:{n_arg}。必须是 1 到 10 的整数" if not api_url: - return "图片生成失败:未配置 models.image_gen.api_url" + return ( + "图片生成失败:未配置 models.image_edit.api_url" + if use_reference_images + else "图片生成失败:未配置 models.image_gen.api_url" + ) if not api_key: - return "图片生成失败:未配置 models.image_gen.api_key" + return ( + "图片生成失败:未配置 models.image_edit.api_key" + if use_reference_images + else "图片生成失败:未配置 models.image_gen.api_key" + ) used_model = model_name or "openai-image-gen" - generated_result = await _call_openai_models( - prompt=prompt, - api_url=api_url, - api_key=api_key, - model_name=model_name, - size=size, - quality=quality, - style=style, - response_format=response_format, - n=n_value, - timeout_val=timeout_val, - context=context, - ) + if use_reference_images: + from Undefined.utils.request_params import merge_request_params + + ( + reference_image_paths, + reference_error, + ) = await _resolve_reference_image_paths(reference_image_uids, context) + if reference_error: + return reference_error + generated_result = await _call_openai_models_edit( + prompt=prompt, + api_url=api_url, + api_key=api_key, + model_name=model_name, + size=size, + quality=quality, + style=style, + response_format=response_format, + n=n_value, + timeout_val=timeout_val, + reference_image_paths=reference_image_paths or [], + extra_params=merge_request_params( + gen_cfg.request_params, + edit_cfg.request_params, + ), + context=context, + ) + else: + generated_result = await _call_openai_models( + prompt=prompt, + api_url=api_url, + api_key=api_key, + model_name=model_name, + size=size, + quality=quality, + style=style, + 
response_format=response_format, + n=n_value, + timeout_val=timeout_val, + context=context, + ) else: return ( f"未知的生图 provider: {provider}," diff --git a/src/Undefined/skills/http_client.py b/src/Undefined/skills/http_client.py index 70f71431..b2aa11fe 100644 --- a/src/Undefined/skills/http_client.py +++ b/src/Undefined/skills/http_client.py @@ -27,6 +27,7 @@ async def request_with_retry( params: dict[str, Any] | None = None, json_data: Any | None = None, data: Any | None = None, + files: Any | None = None, headers: dict[str, str] | None = None, timeout: float | None = None, default_timeout: float = 480.0, @@ -55,6 +56,7 @@ async def request_with_retry( params=params, json=json_data, data=data, + files=files, headers=headers, ) if ( diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 7c368bc9..6b713454 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -41,6 +41,12 @@ def _make_runtime_config(*, request_params: dict[str, Any] | None = None) -> Any model_name="grok-imagine-1.0", request_params=request_params or {}, ), + models_image_edit=SimpleNamespace( + api_url="https://edit.example.com", + api_key="sk-edit", + model_name="grok-edit-1.0", + request_params={}, + ), chat_model=SimpleNamespace( api_url="https://chat.example.com", api_key="sk-chat", @@ -441,3 +447,108 @@ async def _send_image( assert sent["target_id"] == 10001 assert sent["message_type"] == "group" assert Path(sent["file_path"]).read_bytes() == _PNG_BYTES + + +@pytest.mark.asyncio +async def test_execute_models_reference_images_uses_edit_endpoint_and_config( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + runtime_config = _make_runtime_config(request_params={}) + runtime_config.models_image_edit.request_params = {"background": "transparent"} + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: runtime_config, + ) + + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + 
seen_request: dict[str, Any] = {} + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry( + method: str, + url: str, + **kwargs: Any, + ) -> _FakeResponse: + seen_request["method"] = method + seen_request["url"] = url + seen_request["data"] = kwargs.get("data") + seen_request["files"] = kwargs.get("files") + return _FakeResponse() + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "group:10001", + _PNG_BYTES, + kind="image", + display_name="ref.png", + source_kind="test", + ) + + result = await ai_draw_handler.execute( + { + "prompt": "use this as reference", + "size": "1024x1024", + "reference_image_uids": [record.uid], + }, + { + "attachment_registry": registry, + "request_type": "group", + "group_id": 10001, + }, + ) + + assert result.startswith('已生成图片,可在回复中插入 None: + monkeypatch.setattr( + "Undefined.config.get_config", + lambda strict=False: _make_runtime_config(request_params={}), + ) + + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "group:10001", + b"hello", + kind="file", + display_name="demo.txt", + source_kind="test", + ) + + result = await ai_draw_handler.execute( + { + "prompt": "use this as reference", + "reference_image_uids": [record.uid], + }, + { + "attachment_registry": registry, + "request_type": "group", + "group_id": 10001, + }, + ) + + assert result == f"参考图 UID 不是图片:{record.uid}" diff --git a/tests/test_config_request_params.py b/tests/test_config_request_params.py index 82ba2be0..528a78d7 100644 --- a/tests/test_config_request_params.py +++ b/tests/test_config_request_params.py @@ -114,6 +114,22 @@ def 
test_model_request_params_load_inherit_and_new_transport_fields( [models.rerank.request_params] priority = "high" + +[models.image_gen] +api_url = "https://image.example.com/v1" +api_key = "sk-image" +model_name = "gpt-image-gen" + +[models.image_gen.request_params] +temperature = 0.8 + +[models.image_edit] +api_url = "https://edit.example.com/v1" +api_key = "sk-image-edit" +model_name = "gpt-image-edit" + +[models.image_edit.request_params] +background = "transparent" """, ) @@ -194,6 +210,14 @@ def test_model_request_params_load_inherit_and_new_transport_fields( "metadata": {"source": "embed"}, } assert cfg.rerank_model.request_params == {"priority": "high"} + assert cfg.models_image_gen.api_url == "https://image.example.com/v1" + assert cfg.models_image_gen.api_key == "sk-image" + assert cfg.models_image_gen.model_name == "gpt-image-gen" + assert cfg.models_image_gen.request_params == {"temperature": 0.8} + assert cfg.models_image_edit.api_url == "https://edit.example.com/v1" + assert cfg.models_image_edit.api_key == "sk-image-edit" + assert cfg.models_image_edit.model_name == "gpt-image-edit" + assert cfg.models_image_edit.request_params == {"background": "transparent"} def test_naga_model_request_params_override_security_defaults(tmp_path: Path) -> None: From 73f380144c95d71ea4c21ece43e69c13d5814a4c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Wed, 1 Apr 2026 18:24:46 +0800 Subject: [PATCH 10/21] fix(cognitive): refresh profile display names on rename --- src/Undefined/cognitive/service.py | 140 +++++++++++++++++++++++++++++ src/Undefined/handlers.py | 73 +++++++++++++++ tests/test_cognitive_service.py | 98 +++++++++++++++++++- 3 files changed, 309 insertions(+), 2 deletions(-) diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index c79a1aa7..0157daf3 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service.py @@ -116,6 +116,77 @@ def _resolve_auto_request_type( return "" +def 
_parse_profile_markdown(markdown: str) -> tuple[dict[str, Any], str] | None: + text = str(markdown or "") + if not text.startswith("---"): + return None + try: + import yaml + + parts = text[3:].split("---", 1) + if len(parts) != 2: + return None + frontmatter = yaml.safe_load(parts[0]) + if not isinstance(frontmatter, dict): + return None + body = parts[1].lstrip("\n") + return frontmatter, body + except Exception: + return None + + +def _serialize_profile_markdown(frontmatter: dict[str, Any], body: str) -> str: + import yaml + + return f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{body}" + + +def _normalize_profile_tags(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item).strip() for item in value if str(item).strip()] + + +def _current_profile_name(entity_type: str, frontmatter: dict[str, Any]) -> str: + if entity_type == "user": + return str(frontmatter.get("nickname") or frontmatter.get("name") or "").strip() + return str(frontmatter.get("group_name") or frontmatter.get("name") or "").strip() + + +def _build_profile_vector_payload( + *, + entity_type: str, + entity_id: str, + effective_name: str, + tags: list[str], + summary: str, +) -> tuple[str, dict[str, Any]]: + profile_doc_lines: list[str] = [] + if entity_type == "user": + profile_doc_lines.append(f"昵称: {effective_name}") + profile_doc_lines.append(f"QQ号: {entity_id}") + else: + profile_doc_lines.append(f"群名: {effective_name}") + profile_doc_lines.append(f"群号: {entity_id}") + if tags: + profile_doc_lines.append(f"标签: {', '.join(tags)}") + profile_doc_lines.append(summary) + profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) + + metadata: dict[str, Any] = { + "entity_type": entity_type, + "entity_id": entity_id, + "name": effective_name, + } + if entity_type == "user": + metadata["nickname"] = effective_name + metadata["qq"] = entity_id + else: + metadata["group_name"] = effective_name + metadata["group_id"] = entity_id + return 
profile_doc, metadata + + class CognitiveService: def __init__( self, @@ -169,6 +240,75 @@ async def _prepare_query_embedding(self, query: str) -> list[float] | None: def enabled(self) -> bool: return bool(self._config_getter().enabled) + async def sync_profile_display_name( + self, + *, + entity_type: str, + entity_id: str, + preferred_name: str, + ) -> bool: + normalized_entity_type = str(entity_type or "").strip().lower() + normalized_entity_id = str(entity_id or "").strip() + normalized_name = str(preferred_name or "").strip() + if normalized_entity_type not in {"user", "group"}: + return False + if not normalized_entity_id or not normalized_name: + return False + if self._profile_storage is None or self._vector_store is None: + return False + + existing = await self._profile_storage.read_profile( + normalized_entity_type, + normalized_entity_id, + ) + if not existing: + return False + + parsed = _parse_profile_markdown(existing) + if parsed is None: + return False + frontmatter, summary = parsed + current_name = _current_profile_name(normalized_entity_type, frontmatter) + if current_name == normalized_name: + return False + + frontmatter["name"] = normalized_name + frontmatter["updated_at"] = datetime.now().isoformat() + if normalized_entity_type == "user": + frontmatter["nickname"] = normalized_name + frontmatter["qq"] = normalized_entity_id + else: + frontmatter["group_name"] = normalized_name + frontmatter["group_id"] = normalized_entity_id + + updated_markdown = _serialize_profile_markdown(frontmatter, summary) + await self._profile_storage.write_profile( + normalized_entity_type, + normalized_entity_id, + updated_markdown, + ) + + profile_doc, profile_metadata = _build_profile_vector_payload( + entity_type=normalized_entity_type, + entity_id=normalized_entity_id, + effective_name=normalized_name, + tags=_normalize_profile_tags(frontmatter.get("tags")), + summary=summary, + ) + await self._vector_store.upsert_profile( + 
f"{normalized_entity_type}:{normalized_entity_id}", + profile_doc, + profile_metadata, + ) + logger.info( + "[认知服务] 已刷新侧写展示名: entity_type=%s entity_id=%s old=%s new=%s", + normalized_entity_type, + normalized_entity_id, + current_name, + normalized_name, + ) + return True + @staticmethod def _uid_candidates(user_id: str, sender_id: str) -> list[str]: values: list[str] = [] diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index decff51a..e5875929 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -156,6 +156,37 @@ async def _collect_message_attachments( ) return result.attachments + async def _refresh_profile_display_names( + self, + *, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + if not cognitive_service or not getattr(cognitive_service, "enabled", False): + return + + if sender_id and sender_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="user", + entity_id=str(sender_id), + preferred_name=sender_name.strip(), + ) + if group_id and group_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="group", + entity_id=str(group_id), + preferred_name=group_name.strip(), + ) + + def _can_refresh_profile_display_names(self) -> bool: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) + async def handle_message(self, event: dict[str, Any]) -> None: """处理收到的消息事件""" if logger.isEnabledFor(logging.DEBUG): @@ -298,6 +329,15 @@ async def handle_message(self, event: dict[str, Any]) -> None: user_name or private_sender_nickname, safe_text[:100], ) + resolved_private_name = (user_name or private_sender_nickname or "").strip() + if resolved_private_name and 
self._can_refresh_profile_display_names(): + self._spawn_background_task( + f"profile_name_refresh_private:{private_sender_id}", + self._refresh_profile_display_names( + sender_id=private_sender_id, + sender_name=resolved_private_name, + ), + ) # 保存私聊消息到历史记录(保存处理后的内容) # 使用新的工具函数解析内容 @@ -439,6 +479,19 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name = group_info.get("group_name", "") except Exception as e: logger.warning(f"获取群聊名失败: {e}") + resolved_group_sender_name = (sender_card or sender_nickname or "").strip() + if (resolved_group_sender_name or str(group_name or "").strip()) and ( + self._can_refresh_profile_display_names() + ): + self._spawn_background_task( + f"profile_name_refresh_group:{group_id}:{sender_id}", + self._refresh_profile_display_names( + sender_id=sender_id, + sender_name=resolved_group_sender_name, + group_id=group_id, + group_name=str(group_name or "").strip(), + ), + ) # 使用新的 utils parsed_content = await parse_message_content_for_history( @@ -598,6 +651,14 @@ async def _record_private_poke_history( display_name = sender_nickname or user_name or f"QQ{user_id}" normalized_user_name = user_name or display_name poke_text = _format_poke_history_text(display_name, user_id) + if display_name.strip() and self._can_refresh_profile_display_names(): + self._spawn_background_task( + f"profile_name_refresh_private_poke:{user_id}", + self._refresh_profile_display_names( + sender_id=user_id, + sender_name=display_name, + ), + ) try: await self.history_manager.add_private_message( @@ -667,6 +728,18 @@ async def _record_group_poke_history( display_name = sender_card or sender_nickname or f"QQ{sender_id}" poke_text = _format_poke_history_text(display_name, sender_id) normalized_group_name = group_name or f"群{group_id}" + if (display_name.strip() or normalized_group_name.strip()) and ( + self._can_refresh_profile_display_names() + ): + self._spawn_background_task( + f"profile_name_refresh_group_poke:{group_id}:{sender_id}", + 
self._refresh_profile_display_names( + sender_id=sender_id, + sender_name=display_name, + group_id=group_id, + group_name=normalized_group_name, + ), + ) try: await self.history_manager.add_group_message( diff --git a/tests/test_cognitive_service.py b/tests/test_cognitive_service.py index 724f2651..b3194eea 100644 --- a/tests/test_cognitive_service.py +++ b/tests/test_cognitive_service.py @@ -21,6 +21,7 @@ class _FakeVectorStore: def __init__(self) -> None: self.last_event_kwargs: dict[str, Any] | None = None self.last_profile_kwargs: dict[str, Any] | None = None + self.last_upsert_profile: tuple[str, str, dict[str, Any]] | None = None self.event_calls: list[dict[str, Any]] = [] self.event_resolver: Callable[[dict[str, Any]], list[dict[str, Any]]] | None = ( None @@ -45,14 +46,35 @@ async def query_profiles( self.last_profile_kwargs = dict(kwargs) return [] + async def upsert_profile( + self, + profile_id: str, + document: str, + metadata: dict[str, Any], + ) -> None: + self.last_upsert_profile = (profile_id, document, metadata) + class _FakeProfileStorage: + def __init__(self, initial_profile: str | None = None) -> None: + self.profile = initial_profile + self.last_write: tuple[str, str, str] | None = None + async def read_profile( self, _entity_type: str, _entity_id: str, ) -> str | None: - return None + return self.profile + + async def write_profile( + self, + entity_type: str, + entity_id: str, + content: str, + ) -> None: + self.profile = content + self.last_write = (entity_type, entity_id, content) class _FakeRetrievalRuntime: @@ -518,8 +540,80 @@ def test_merge_weighted_events_preserves_scope_rank_order() -> None: [(scoped_events, 1.0)], top_k=2, ) - assert [item["document"] for item in merged] == [ "更新但稍弱相似度", "更老但更高相似度", ] + + +@pytest.mark.asyncio +async def test_sync_profile_display_name_updates_existing_profile_and_vector() -> None: + existing_profile = """--- +entity_type: user +entity_id: "12345" +name: 旧昵称 +nickname: 旧昵称 +tags: + - 开发者 +updated_at: 
"2026-04-01T00:00:00" +--- +喜欢 Python +""" + vector_store = _FakeVectorStore() + profile_storage = _FakeProfileStorage(existing_profile) + service = CognitiveService( + config_getter=lambda: SimpleNamespace(enabled=True), + vector_store=vector_store, + job_queue=_FakeJobQueue(), + profile_storage=profile_storage, + reranker=None, + ) + + updated = await service.sync_profile_display_name( + entity_type="user", + entity_id="12345", + preferred_name="新昵称", + ) + + assert updated is True + assert profile_storage.last_write is not None + assert "name: 新昵称" in profile_storage.last_write[2] + assert "nickname: 新昵称" in profile_storage.last_write[2] + assert vector_store.last_upsert_profile is not None + profile_id, document, metadata = vector_store.last_upsert_profile + assert profile_id == "user:12345" + assert "昵称: 新昵称" in document + assert metadata["name"] == "新昵称" + assert metadata["nickname"] == "新昵称" + + +@pytest.mark.asyncio +async def test_sync_profile_display_name_noops_when_name_unchanged() -> None: + existing_profile = """--- +entity_type: group +entity_id: "10001" +name: 测试群 +group_name: 测试群 +updated_at: "2026-04-01T00:00:00" +--- +一个群聊 +""" + vector_store = _FakeVectorStore() + profile_storage = _FakeProfileStorage(existing_profile) + service = CognitiveService( + config_getter=lambda: SimpleNamespace(enabled=True), + vector_store=vector_store, + job_queue=_FakeJobQueue(), + profile_storage=profile_storage, + reranker=None, + ) + + updated = await service.sync_profile_display_name( + entity_type="group", + entity_id="10001", + preferred_name="测试群", + ) + + assert updated is False + assert profile_storage.last_write is None + assert vector_store.last_upsert_profile is None From 91e77bccc42a794a02ee083700bf35845f7ea6ed Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 10:54:48 +0800 Subject: [PATCH 11/21] feat(multimodal): adapt image tools to attachment uid mechanism Add delivery parameter (embed/send, default embed) to 
render_markdown, render_latex, render_html, get_picture, and minecraft_skin so they register images via AttachmentRegistry and return tags. Update wenchang_dijun to register sign images by UID instead of returning raw URLs. Add new fetch_image_uid tool for converting arbitrary image URLs into attachment UIDs. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../tools/minecraft_skin/config.json | 13 +- .../tools/minecraft_skin/handler.py | 64 ++++- .../tools/wenchang_dijun/handler.py | 23 +- .../skills/tools/fetch_image_uid/config.json | 21 ++ .../skills/tools/fetch_image_uid/handler.py | 46 ++++ .../skills/tools/get_picture/config.json | 13 +- .../skills/tools/get_picture/handler.py | 223 ++++++++++++++---- .../toolsets/render/render_html/config.json | 13 +- .../toolsets/render/render_html/handler.py | 74 +++++- .../toolsets/render/render_latex/config.json | 13 +- .../toolsets/render/render_latex/handler.py | 93 ++++++-- .../render/render_markdown/config.json | 13 +- .../render/render_markdown/handler.py | 87 +++++-- 13 files changed, 577 insertions(+), 119 deletions(-) create mode 100644 src/Undefined/skills/tools/fetch_image_uid/config.json create mode 100644 src/Undefined/skills/tools/fetch_image_uid/handler.py diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/config.json b/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/config.json index 0dfcc7bc..411d6c1e 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/config.json +++ b/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/config.json @@ -2,7 +2,7 @@ "type": "function", "function": { "name": "minecraft_skin", - "description": "获取 Minecraft 玩家皮肤/头像。", + "description": "获取 Minecraft 玩家皮肤/头像。默认返回可嵌入回复的图片 UID(embed),也可直接发送到目标(send)。", "parameters": { "type": "object", "properties": { @@ -27,17 +27,22 @@ "type": "integer", "description": "体型大小 (1-10)" }, + "delivery": { + "type": "string", + "description": 
"图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", + "enum": ["embed", "send"] + }, "target_id": { "type": "integer", - "description": "发送目标的 ID (群号或 QQ 号)" + "description": "发送目标的 ID (群号或 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" }, "message_type": { "type": "string", - "description": "消息类型 (group 或 private)", + "description": "消息类型 (group 或 private,仅 delivery=send 时需要,不提供则从当前会话推断)", "enum": ["group", "private"] } }, - "required": ["name", "target_id", "message_type"] + "required": ["name"] } } } \ No newline at end of file diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/handler.py index 8204ffd4..d19162cb 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/minecraft_skin/handler.py @@ -1,13 +1,36 @@ +from __future__ import annotations + from typing import Any, Dict import logging import uuid +from Undefined.attachments import scope_from_context from Undefined.skills.http_client import request_with_retry from Undefined.skills.http_config import get_request_timeout, get_xingzhige_url logger = logging.getLogger(__name__) +def _resolve_send_target( + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + """从参数或 context 推断发送目标。""" + if target_id is not None and message_type is not None: + return target_id, message_type, None + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + gid = context.get("group_id") + if gid is not None: + return gid, "group", None + if request_type == "private": + uid = context.get("user_id") + if uid is not None: + return uid, "private", None + return None, None, "获取成功,但缺少发送目标参数" + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: """获取指定我的世界(Minecraft)正版用户的皮肤图片链接""" name = args.get("name") @@ 
-15,9 +38,13 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: overlay = args.get("overlay", True) size = args.get("size", 160) scale = args.get("scale", 6) + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id = args.get("target_id") message_type = args.get("message_type") + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + url = get_xingzhige_url("/API/get_Minecraft_skins/") params = { "name": name, @@ -43,7 +70,7 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if "application/json" in content_type: return f"获取失败: {response.text}" - # 假设是图片 + # 保存图片到缓存 filename = f"mc_skin_{uuid.uuid4().hex[:8]}.png" from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir @@ -52,10 +79,41 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: with open(filepath, "wb") as f: f.write(response.content) + # 注册到附件系统 + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + record: Any = None + if attachment_registry is not None and scope_key: + try: + record = await attachment_registry.register_local_file( + scope_key, + filepath, + kind="image", + display_name=filename, + source_kind="minecraft_skin", + source_ref=f"minecraft_skin:{name}", + ) + except Exception as exc: + logger.warning("注册 Minecraft 皮肤到附件系统失败: %s", exc) + + if delivery == "embed": + if record is None: + return "获取成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + return f'' + + # delivery == "send" + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, message_type, context + ) + if target_error or resolved_target_id is None or resolved_message_type is None: + return target_error or "获取成功,但缺少发送目标参数" + send_image_callback = context.get("send_image_callback") if send_image_callback: - await send_image_callback(target_id, message_type, str(filepath)) - return f"Minecraft 皮肤/头像已发送给 
{message_type} {target_id}" + await send_image_callback( + resolved_target_id, resolved_message_type, str(filepath) + ) + return f"Minecraft 皮肤/头像已发送给 {resolved_message_type} {resolved_target_id}" return "发送图片回调未设置,图片已保存但无法发送。" except Exception as e: diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/wenchang_dijun/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/wenchang_dijun/handler.py index d4cf661a..c151319e 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/wenchang_dijun/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/wenchang_dijun/handler.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from typing import Any, Dict import logging import httpx +from Undefined.attachments import scope_from_context from Undefined.skills.http_client import get_json_with_retry from Undefined.skills.http_config import get_xxapi_url @@ -35,7 +38,25 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: result += f"【签文】\n{content}\n" if pic: - result += f"\n签文图片:{pic}" + # 尝试注册图片到附件系统 + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry is not None and scope_key: + try: + record = await attachment_registry.register_remote_url( + scope_key, + pic, + kind="image", + display_name=f"wenchang_{fortune_id}.jpg", + source_kind="wenchang_dijun", + source_ref=pic, + ) + result += f'\n签文图片:' + except Exception as exc: + logger.warning("注册文昌帝君签文图片失败: %s", exc) + result += f"\n签文图片:{pic}" + else: + result += f"\n签文图片:{pic}" return result diff --git a/src/Undefined/skills/tools/fetch_image_uid/config.json b/src/Undefined/skills/tools/fetch_image_uid/config.json new file mode 100644 index 00000000..87e5a05c --- /dev/null +++ b/src/Undefined/skills/tools/fetch_image_uid/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "fetch_image_uid", + "description": "从 URL 获取图片并注册到附件系统,返回可在回复中嵌入的图片 
UID。仅支持图片类型(PNG, JPEG, GIF, WebP, BMP)。", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "图片 URL(必须是 http/https 链接)" + }, + "display_name": { + "type": "string", + "description": "图片的显示名称(可选,默认从 URL 推断)" + } + }, + "required": ["url"] + } + } +} diff --git a/src/Undefined/skills/tools/fetch_image_uid/handler.py b/src/Undefined/skills/tools/fetch_image_uid/handler.py new file mode 100644 index 00000000..11724cf8 --- /dev/null +++ b/src/Undefined/skills/tools/fetch_image_uid/handler.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict + +from Undefined.attachments import scope_from_context + +logger = logging.getLogger(__name__) + +_IMAGE_MIME_PREFIX = "image/" + + +async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: + """从 URL 获取图片并注册到附件系统,返回图片 UID。""" + url = str(args.get("url", "") or "").strip() + display_name = str(args.get("display_name", "") or "").strip() or None + + if not url: + return "URL 不能为空" + if not url.startswith(("http://", "https://")): + return "URL 必须是 http 或 https 链接" + + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry is None or not scope_key: + return "当前会话不支持附件注册" + + try: + record = await attachment_registry.register_remote_url( + scope_key, + url, + kind="image", + display_name=display_name, + source_kind="fetch_image_uid", + source_ref=url, + ) + except Exception as exc: + logger.exception("fetch_image_uid 注册失败: %s", exc) + return f"获取图片失败:{exc}" + + # 验证是否为图片类型 + mime = str(getattr(record, "mime_type", "") or "").strip().lower() + if mime and not mime.startswith(_IMAGE_MIME_PREFIX): + return f"URL 内容不是图片类型(检测到 {mime}),仅支持图片" + + return f'' diff --git a/src/Undefined/skills/tools/get_picture/config.json b/src/Undefined/skills/tools/get_picture/config.json index d773604b..463179ba 100644 --- 
a/src/Undefined/skills/tools/get_picture/config.json +++ b/src/Undefined/skills/tools/get_picture/config.json @@ -2,18 +2,23 @@ "type": "function", "function": { "name": "get_picture", - "description": "获取指定数量的指定类型的图片并发送到群聊或指定私聊。支持白丝、黑丝、头像、JK、二次元、小姐姐、壁纸、原神、历史上的今天、4K图片、美腿十一种类型。二次元类型支持选择手机端或PC端。4K图片支持选择二次元或风景。默认获取二次元图片,默认使用PC端。", + "description": "获取指定数量的指定类型的图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到群聊或指定私聊(send)。支持白丝、黑丝、头像、JK、二次元、小姐姐、壁纸、原神、历史上的今天、4K图片、美腿十一种类型。二次元类型支持选择手机端或PC端。4K图片支持选择二次元或风景。默认获取二次元图片,默认使用PC端。", "parameters": { "type": "object", "properties": { + "delivery": { + "type": "string", + "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", + "enum": ["embed", "send"] + }, "message_type": { "type": "string", - "description": "消息类型", + "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", "enum": ["group", "private"] }, "target_id": { "type": "integer", - "description": "目标 ID(群聊为群号,私聊为用户 QQ 号)" + "description": "目标 ID(群聊为群号,私聊为用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" }, "picture_type": { "type": "string", @@ -39,7 +44,7 @@ "default": "acg" } }, - "required": ["message_type", "target_id"] + "required": [] } } } \ No newline at end of file diff --git a/src/Undefined/skills/tools/get_picture/handler.py b/src/Undefined/skills/tools/get_picture/handler.py index bcb05e32..b13df227 100644 --- a/src/Undefined/skills/tools/get_picture/handler.py +++ b/src/Undefined/skills/tools/get_picture/handler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict import logging import httpx @@ -5,6 +7,7 @@ import uuid from pathlib import Path +from Undefined.attachments import scope_from_context from Undefined.config import get_config from Undefined.skills.http_client import request_with_retry @@ -39,6 +42,20 @@ "meitui": "美腿", } +# 中文数字映射 +_CN_NUMS = { + 1: "一", + 2: "二", + 3: "三", + 4: "四", + 5: "五", + 6: "六", + 7: "七", + 8: "八", + 9: "九", + 10: "十", +} + def _get_xxapi_base_url() -> str: config = get_config(strict=False) @@ -52,7 +69,28 @@ 
def _get_timeout_seconds() -> float: return timeout if timeout > 0 else 480.0 +def _resolve_send_target( + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + """从参数或 context 推断发送目标。""" + if target_id is not None and message_type is not None: + return target_id, message_type, None + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + gid = context.get("group_id") + if gid is not None: + return gid, "group", None + if request_type == "private": + uid = context.get("user_id") + if uid is not None: + return uid, "private", None + return None, None, "获取成功,但缺少发送目标参数" + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() message_type = args.get("message_type") target_id = args.get("target_id") picture_type = args.get("picture_type", "acg") @@ -61,29 +99,23 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: fourk_type = args.get("fourk_type", "acg") # 参数验证 - if not message_type: - return "❌ 消息类型不能为空" - if message_type not in ["group", "private"]: - return "❌ 消息类型必须是 group(群聊)或 private(私聊)" - if not target_id: - return "❌ 目标 ID 不能为空" - if not isinstance(target_id, int): - return "❌ 目标 ID 必须是整数" + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + + if delivery == "send": + if message_type and message_type not in ("group", "private"): + return "消息类型必须是 group(群聊)或 private(私聊)" + if picture_type not in API_PATHS: - return f"❌ 不支持的图片类型: {picture_type}\n支持的类型: {', '.join(TYPE_NAMES.values())}" + return f"不支持的图片类型: {picture_type}\n支持的类型: {', '.join(TYPE_NAMES.values())}" if not isinstance(count, int): - return "❌ 图片数量必须是整数" + return "图片数量必须是整数" if count < 1 or count > 10: - return "❌ 图片数量必须在 1-10 之间" - if picture_type == "acg" and device not in ["pc", "wap"]: - return "❌ 设备类型必须是 pc(电脑端)或 wap(手机端)" 
- if picture_type == "random4kPic" and fourk_type not in ["acg", "wallpaper"]: - return "❌ 4K图片类型必须是 acg(二次元)或 wallpaper(风景)" - - # 获取发送图片回调 - send_image_callback = context.get("send_image_callback") - if not send_image_callback: - return "发送图片回调未设置" + return "图片数量必须在 1-10 之间" + if picture_type == "acg" and device not in ("pc", "wap"): + return "设备类型必须是 pc(电脑端)或 wap(手机端)" + if picture_type == "random4kPic" and fourk_type not in ("acg", "wallpaper"): + return "4K图片类型必须是 acg(二次元)或 wallpaper(风景)" # 构造请求参数 params: Dict[str, Any] = {"return": "json"} @@ -92,11 +124,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: elif picture_type == "random4kPic": params["type"] = fourk_type - # 获取图片 - success_count = 0 - fail_count = 0 - local_image_paths: list[str] = [] - # 创建图片保存目录 from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir @@ -106,6 +133,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: base_url = _get_xxapi_base_url() api_url = f"{base_url}{API_PATHS[picture_type]}" + # 获取图片 + success_count = 0 + fail_count = 0 + local_image_paths: list[str] = [] + for i in range(count): try: logger.info( @@ -186,14 +218,117 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if success_count == 0: return f"获取 {TYPE_NAMES[picture_type]} 图片失败,请稍后重试" - # 发送图片 + device_text = f"({device}端)" if picture_type == "acg" else "" + fourk_text = f"({fourk_type})" if picture_type == "random4kPic" else "" + + if delivery == "embed": + return await _deliver_embed( + local_image_paths, + success_count, + fail_count, + picture_type, + device_text, + fourk_text, + context, + ) + else: + return await _deliver_send( + local_image_paths, + success_count, + fail_count, + picture_type, + device_text, + fourk_text, + target_id, + message_type, + context, + ) + + +async def _deliver_embed( + local_image_paths: list[str], + success_count: int, + fail_count: int, + picture_type: str, + device_text: str, + fourk_text: str, + 
context: Dict[str, Any], +) -> str: + """注册图片到附件系统并返回 UID 标签。""" + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + if attachment_registry is None or not scope_key: + return "获取成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + + uid_tags: list[str] = [] + register_fail = 0 + for image_path in local_image_paths: + try: + record = await attachment_registry.register_local_file( + scope_key, + image_path, + kind="image", + display_name=Path(image_path).name, + source_kind="get_picture", + source_ref=f"get_picture:{picture_type}", + ) + uid_tags.append(f'') + except Exception as exc: + logger.warning("注册图片到附件系统失败: %s", exc) + register_fail += 1 + + # 注册后删除缓存文件(register_local_file 已复制到 ATTACHMENT_CACHE_DIR) + try: + Path(image_path).unlink() + except Exception as e: + logger.warning(f"删除图片缓存文件失败: {e}") + + if not uid_tags: + return "获取成功,但注册到附件系统全部失败" + + success_cn = _CN_NUMS.get(len(uid_tags), str(len(uid_tags))) + result = f"已获取 {success_cn} 张 {TYPE_NAMES[picture_type]} 图片{device_text}{fourk_text}:\n" + result += "\n".join(uid_tags) + + total_fail = fail_count + register_fail + if total_fail > 0: + fail_cn = _CN_NUMS.get(total_fail, str(total_fail)) + result += f"\n(失败 {fail_cn} 张)" + + return result + + +async def _deliver_send( + local_image_paths: list[str], + success_count: int, + fail_count: int, + picture_type: str, + device_text: str, + fourk_text: str, + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> str: + """直接发送图片到目标。""" + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, message_type, context + ) + if target_error or resolved_target_id is None or resolved_message_type is None: + return target_error or "获取成功,但缺少发送目标参数" + + send_image_callback = context.get("send_image_callback") + if not send_image_callback: + return "发送图片回调未设置" + + send_fail = 0 for idx, image_path in enumerate(local_image_paths, 1): try: logger.info( - f"正在发送第 
{idx}/{success_count} 张图片到 {message_type} {target_id}" + f"正在发送第 {idx}/{success_count} 张图片到 {resolved_message_type} {resolved_target_id}" + ) + await send_image_callback( + resolved_target_id, resolved_message_type, image_path ) - logger.info(f"图片路径: {image_path}") - await send_image_callback(target_id, message_type, image_path) logger.info(f"图片 {idx} 发送成功") # 删除本地图片文件 @@ -207,29 +342,13 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: await asyncio.sleep(0.5) except Exception as e: logger.exception(f"发送图片失败: {e}") - fail_count += 1 + send_fail += 1 - # 返回结果 - device_text = f"({device}端)" if picture_type == "acg" else "" - fourk_text = f"({fourk_type})" if picture_type == "random4kPic" else "" + total_fail = fail_count + send_fail + success_cn = _CN_NUMS.get(success_count, str(success_count)) - # 中文数字映射 - cn_nums = { - 1: "一", - 2: "二", - 3: "三", - 4: "四", - 5: "五", - 6: "六", - 7: "七", - 8: "八", - 9: "九", - 10: "十", - } - success_cn = cn_nums.get(success_count, str(success_count)) - fail_cn = cn_nums.get(fail_count, str(fail_count)) - - if fail_count == 0: - return f"✅ 已成功发送 {success_cn} 张 {TYPE_NAMES[picture_type]} 图片{device_text}{fourk_text}到 {message_type} {target_id}" + if total_fail == 0: + return f"已成功发送 {success_cn} 张 {TYPE_NAMES[picture_type]} 图片{device_text}{fourk_text}到 {resolved_message_type} {resolved_target_id}" else: - return f"⚠️ 已发送 {success_cn} 张 {TYPE_NAMES[picture_type]} 图片{device_text}{fourk_text},失败 {fail_cn} 张" + fail_cn = _CN_NUMS.get(total_fail, str(total_fail)) + return f"已发送 {success_cn} 张 {TYPE_NAMES[picture_type]} 图片{device_text}{fourk_text},失败 {fail_cn} 张" diff --git a/src/Undefined/skills/toolsets/render/render_html/config.json b/src/Undefined/skills/toolsets/render/render_html/config.json index 1e72dd7e..53a1b17b 100644 --- a/src/Undefined/skills/toolsets/render/render_html/config.json +++ b/src/Undefined/skills/toolsets/render/render_html/config.json @@ -2,7 +2,7 @@ "type": "function", "function": { "name": 
"render_html", - "description": "将 HTML 内容渲染为图片并发送到指定目标(群聊或私聊)。支持完整的 HTML 文档,包括内联 CSS 和样式。", + "description": "将 HTML 内容渲染为图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。支持完整的 HTML 文档,包括内联 CSS 和样式。", "parameters": { "type": "object", "properties": { @@ -10,17 +10,22 @@ "type": "string", "description": "要渲染的 HTML 内容。必须是完整的 HTML 文档(包含 、、、 标签)。" }, + "delivery": { + "type": "string", + "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", + "enum": ["embed", "send"] + }, "target_id": { "type": "integer", - "description": "目标 ID(群号或用户 QQ 号)" + "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" }, "message_type": { "type": "string", - "description": "消息类型", + "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", "enum": ["group", "private"] } }, - "required": ["html_content", "target_id", "message_type"] + "required": ["html_content"] } } } \ No newline at end of file diff --git a/src/Undefined/skills/toolsets/render/render_html/handler.py b/src/Undefined/skills/toolsets/render/render_html/handler.py index f837a827..9341fb19 100644 --- a/src/Undefined/skills/toolsets/render/render_html/handler.py +++ b/src/Undefined/skills/toolsets/render/render_html/handler.py @@ -1,23 +1,47 @@ +from __future__ import annotations + from typing import Any, Dict import logging import uuid +from Undefined.attachments import scope_from_context + logger = logging.getLogger(__name__) +def _resolve_send_target( + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + """从参数或 context 推断发送目标。""" + if target_id is not None and message_type is not None: + return target_id, message_type, None + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + gid = context.get("group_id") + if gid is not None: + return gid, "group", None + if request_type == "private": + uid = context.get("user_id") + if uid is not None: + return uid, "private", None + return None, None, 
"渲染成功,但缺少发送目标参数" + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: - """将其余格式(如 Markdown)渲染为 HTML 格式""" + """将 HTML 内容渲染为图片""" html_content = args.get("html_content", "") + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id = args.get("target_id") message_type = args.get("message_type") if not html_content: return "HTML 内容不能为空" - if not target_id: - return "目标 ID 不能为空" - if not message_type: - return "消息类型不能为空" - if message_type not in ["group", "private"]: + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + + if delivery == "send" and message_type and message_type not in ("group", "private"): return "消息类型必须是 group 或 private" try: @@ -33,11 +57,45 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: await render_html_to_image(html_content, str(filepath)) + # 注册到附件系统 + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + record: Any = None + if attachment_registry is not None and scope_key: + try: + record = await attachment_registry.register_local_file( + scope_key, + filepath, + kind="image", + display_name=filename, + source_kind="rendered_image", + source_ref="render_html", + ) + except Exception as exc: + logger.warning("注册渲染图片到附件系统失败: %s", exc) + + if delivery == "embed": + cleanup_cache_dir(RENDER_CACHE_DIR) + if record is None: + return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + return f'' + + # delivery == "send" + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, message_type, context + ) + if target_error or resolved_target_id is None or resolved_message_type is None: + return target_error or "渲染成功,但缺少发送目标参数" + send_image_callback = context.get("send_image_callback") if send_image_callback: - await send_image_callback(target_id, message_type, str(filepath)) + await send_image_callback( + resolved_target_id, resolved_message_type, 
str(filepath) + ) cleanup_cache_dir(RENDER_CACHE_DIR) - return f"HTML 图片已渲染并发送到 {message_type} {target_id}" + return ( + f"HTML 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + ) else: return "发送图片回调未设置" diff --git a/src/Undefined/skills/toolsets/render/render_latex/config.json b/src/Undefined/skills/toolsets/render/render_latex/config.json index 20eecf66..1d3c48ef 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/config.json +++ b/src/Undefined/skills/toolsets/render/render_latex/config.json @@ -2,7 +2,7 @@ "type": "function", "function": { "name": "render_latex", - "description": "将 LaTeX 文本渲染为图片并发送到指定目标(群聊或私聊)。支持完整的 LaTeX 语法(包含 \\begin 和 \\end)。", + "description": "将 LaTeX 文本渲染为图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。支持完整的 LaTeX 语法(包含 \\begin 和 \\end)。", "parameters": { "type": "object", "properties": { @@ -10,17 +10,22 @@ "type": "string", "description": "要渲染的 LaTeX 内容。必须是完整格式(包含 \\begin 和 \\end)。" }, + "delivery": { + "type": "string", + "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", + "enum": ["embed", "send"] + }, "target_id": { "type": "integer", - "description": "目标 ID(群号或用户 QQ 号)" + "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" }, "message_type": { "type": "string", - "description": "消息类型", + "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", "enum": ["group", "private"] } }, - "required": ["content", "target_id", "message_type"] + "required": ["content"] } } } diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 68ca7012..26774d64 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -1,25 +1,49 @@ +from __future__ import annotations + from typing import Any, Dict import logging import uuid import matplotlib.pyplot as plt import matplotlib +from Undefined.attachments import scope_from_context + logger = 
logging.getLogger(__name__) +def _resolve_send_target( + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + """从参数或 context 推断发送目标。""" + if target_id is not None and message_type is not None: + return target_id, message_type, None + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + gid = context.get("group_id") + if gid is not None: + return gid, "group", None + if request_type == "private": + uid = context.get("user_id") + if uid is not None: + return uid, "private", None + return None, None, "渲染成功,但缺少发送目标参数" + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: - """渲染 LaTeX 数学公式为图片或文本描述""" + """渲染 LaTeX 数学公式为图片""" content = args.get("content", "") + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id = args.get("target_id") message_type = args.get("message_type") if not content: return "内容不能为空" - if not target_id: - return "目标 ID 不能为空" - if not message_type: - return "消息类型不能为空" - if message_type not in ["group", "private"]: + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + + if delivery == "send" and message_type and message_type not in ("group", "private"): return "消息类型必须是 group 或 private" try: @@ -50,26 +74,59 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: plt.savefig(filepath, dpi=150, bbox_inches="tight", pad_inches=0.1) plt.close(fig) - send_image_callback = context.get("send_image_callback") + # 注册到附件系统 + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + record: Any = None + if attachment_registry is not None and scope_key: + try: + record = await attachment_registry.register_local_file( + scope_key, + filepath, + kind="image", + display_name=filename, + source_kind="rendered_image", + source_ref="render_latex", + ) + except Exception as exc: + 
logger.warning("注册渲染图片到附件系统失败: %s", exc) + + if delivery == "embed": + cleanup_cache_dir(RENDER_CACHE_DIR) + if record is None: + return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + return f'' + + # delivery == "send" + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, message_type, context + ) + if target_error or resolved_target_id is None or resolved_message_type is None: + return target_error or "渲染成功,但缺少发送目标参数" + sender = context.get("sender") + send_image_callback = context.get("send_image_callback") if sender: from pathlib import Path - message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" - - if message_type == "group": - await sender.send_group_message(int(target_id), message) - elif message_type == "private": - await sender.send_private_message(int(target_id), message) - + cq_message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" + if resolved_message_type == "group": + await sender.send_group_message(int(resolved_target_id), cq_message) + elif resolved_message_type == "private": + await sender.send_private_message(int(resolved_target_id), cq_message) cleanup_cache_dir(RENDER_CACHE_DIR) - return f"LaTeX 图片已渲染并发送到 {message_type} {target_id}" - + return ( + f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + ) elif send_image_callback: - await send_image_callback(target_id, message_type, str(filepath)) + await send_image_callback( + resolved_target_id, resolved_message_type, str(filepath) + ) cleanup_cache_dir(RENDER_CACHE_DIR) - return f"LaTeX 图片已渲染并发送到 {message_type} {target_id}" + return ( + f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + ) else: return "发送图片回调未设置" diff --git a/src/Undefined/skills/toolsets/render/render_markdown/config.json b/src/Undefined/skills/toolsets/render/render_markdown/config.json index 060e3acd..9b768e7d 100644 --- a/src/Undefined/skills/toolsets/render/render_markdown/config.json +++ 
b/src/Undefined/skills/toolsets/render/render_markdown/config.json @@ -2,7 +2,7 @@ "type": "function", "function": { "name": "render_markdown", - "description": "将 Markdown 文本渲染为图片并发送到指定目标(群聊或私聊)。支持标准 Markdown 格式,包括标题、列表、代码块、表格等。", + "description": "将 Markdown 文本渲染为图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。支持标准 Markdown 格式,包括标题、列表、代码块、表格等。", "parameters": { "type": "object", "properties": { @@ -10,17 +10,22 @@ "type": "string", "description": "要渲染的 Markdown 内容。支持标准 Markdown 格式。" }, + "delivery": { + "type": "string", + "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", + "enum": ["embed", "send"] + }, "target_id": { "type": "integer", - "description": "目标 ID(群号或用户 QQ 号)" + "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" }, "message_type": { "type": "string", - "description": "消息类型", + "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", "enum": ["group", "private"] } }, - "required": ["content", "target_id", "message_type"] + "required": ["content"] } } } diff --git a/src/Undefined/skills/toolsets/render/render_markdown/handler.py b/src/Undefined/skills/toolsets/render/render_markdown/handler.py index 15834af1..cda52033 100644 --- a/src/Undefined/skills/toolsets/render/render_markdown/handler.py +++ b/src/Undefined/skills/toolsets/render/render_markdown/handler.py @@ -1,23 +1,47 @@ +from __future__ import annotations + from typing import Any, Dict import logging import uuid +from Undefined.attachments import scope_from_context + logger = logging.getLogger(__name__) +def _resolve_send_target( + target_id: Any, + message_type: Any, + context: Dict[str, Any], +) -> tuple[int | str | None, str | None, str | None]: + """从参数或 context 推断发送目标。""" + if target_id is not None and message_type is not None: + return target_id, message_type, None + request_type = str(context.get("request_type", "") or "").strip().lower() + if request_type == "group": + gid = context.get("group_id") + if gid is not None: + return gid, "group", None + if 
request_type == "private": + uid = context.get("user_id") + if uid is not None: + return uid, "private", None + return None, None, "渲染成功,但缺少发送目标参数" + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: """渲染指定的 Markdown 文本内容""" content = args.get("content", "") + delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id = args.get("target_id") message_type = args.get("message_type") if not content: return "内容不能为空" - if not target_id: - return "目标 ID 不能为空" - if not message_type: - return "消息类型不能为空" - if message_type not in ["group", "private"]: + if delivery not in {"embed", "send"}: + return f"delivery 无效:{delivery}。仅支持 embed 或 send" + + if delivery == "send" and message_type and message_type not in ("group", "private"): return "消息类型必须是 group 或 private" try: @@ -40,26 +64,55 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception(f"Markdown 渲染失败: {e}") return "Markdown 渲染失败,请稍后重试" - send_image_callback = context.get("send_image_callback") + # 注册到附件系统 + attachment_registry = context.get("attachment_registry") + scope_key = scope_from_context(context) + record: Any = None + if attachment_registry is not None and scope_key: + try: + record = await attachment_registry.register_local_file( + scope_key, + filepath, + kind="image", + display_name=filename, + source_kind="rendered_image", + source_ref="render_markdown", + ) + except Exception as exc: + logger.warning("注册渲染图片到附件系统失败: %s", exc) + + if delivery == "embed": + cleanup_cache_dir(RENDER_CACHE_DIR) + if record is None: + return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + return f'' + + # delivery == "send" + resolved_target_id, resolved_message_type, target_error = _resolve_send_target( + target_id, message_type, context + ) + if target_error or resolved_target_id is None or resolved_message_type is None: + return target_error or "渲染成功,但缺少发送目标参数" + sender = context.get("sender") + send_image_callback = 
context.get("send_image_callback") if sender: from pathlib import Path - message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" - - if message_type == "group": - await sender.send_group_message(int(target_id), message) - elif message_type == "private": - await sender.send_private_message(int(target_id), message) - + cq_message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" + if resolved_message_type == "group": + await sender.send_group_message(int(resolved_target_id), cq_message) + elif resolved_message_type == "private": + await sender.send_private_message(int(resolved_target_id), cq_message) cleanup_cache_dir(RENDER_CACHE_DIR) - return f"Markdown 图片已渲染并发送到 {message_type} {target_id}" - + return f"Markdown 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" elif send_image_callback: - await send_image_callback(target_id, message_type, str(filepath)) + await send_image_callback( + resolved_target_id, resolved_message_type, str(filepath) + ) cleanup_cache_dir(RENDER_CACHE_DIR) - return f"Markdown 图片已渲染并发送到 {message_type} {target_id}" + return f"Markdown 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" else: return "发送图片回调未设置" From 0e118ab24c87e1f182bb0b51af3325c331ad999a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 14:39:25 +0800 Subject: [PATCH 12/21] fix(multimodal): harden attachment rendering and image edit upload --- src/Undefined/api/app.py | 11 +- src/Undefined/services/ai_coordinator.py | 3 +- .../tools/ai_draw_one/handler.py | 96 ++++++-------- tests/test_ai_coordinator_queue_routing.py | 38 ++++++ tests/test_ai_draw_one_handler.py | 55 ++++++++ tests/test_runtime_api_chat_stream.py | 124 ++++++++++++++++++ 6 files changed, 265 insertions(+), 62 deletions(-) create mode 100644 tests/test_runtime_api_chat_stream.py diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index fc913107..f1cad772 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1429,14 +1429,11 @@ 
async def _capture_private_message(user_id: int, message: str) -> None: message_queue: asyncio.Queue[str] = asyncio.Queue() async def _capture_private_message_stream(user_id: int, message: str) -> None: + output_count = len(outputs) await _capture_private_message(user_id, message) - rendered = await render_message_with_pic_placeholders( - str(message or "").strip(), - registry=self._ctx.ai.attachment_registry, - scope_key=webui_scope_key, - strict=False, - ) - content = rendered.delivery_text.strip() + if len(outputs) <= output_count: + return + content = outputs[-1].strip() if content: await message_queue.put(content) diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index c73ba903..098d575f 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -196,8 +196,7 @@ async def handle_private_reply( f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" ) full_question = f"""{prompt_prefix} - {escape_xml_text(text)} -{attachment_xml} + {escape_xml_text(text)}{attachment_xml} 【私聊消息】 diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index ec93c4d6..441c671e 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -600,63 +600,53 @@ async def _call_openai_models_edit( if api_key: headers["Authorization"] = f"Bearer {api_key}" - file_handles: list[Any] = [] - files: list[tuple[str, tuple[str, Any, str]]] = [] - try: - for path in reference_image_paths: - file_handle = path.open("rb") - file_handles.append(file_handle) - files.append( + files: list[tuple[str, tuple[str, bytes, str]]] = [] + for path in reference_image_paths: + files.append( + ( + "image", ( - "image", - ( - path.name, - file_handle, - _guess_upload_media_type(path), - 
), - ) + path.name, + path.read_bytes(), + _guess_upload_media_type(path), + ), ) + ) - try: - response = await request_with_retry( - "POST", - url, - data=form_data, - files=files, - headers=headers or None, - timeout=timeout_val, - context=context, - ) - except httpx.HTTPStatusError as exc: - message = _format_upstream_error_message(exc.response) - return f"参考图生图请求失败: HTTP {exc.response.status_code} {message}" - except httpx.TimeoutException: - return f"参考图生图请求超时({timeout_val:.0f}s)" - except httpx.RequestError as exc: - return f"参考图生图请求失败: {exc}" + try: + response = await request_with_retry( + "POST", + url, + data=form_data, + files=files, + headers=headers or None, + timeout=timeout_val, + context=context, + ) + except httpx.HTTPStatusError as exc: + message = _format_upstream_error_message(exc.response) + return f"参考图生图请求失败: HTTP {exc.response.status_code} {message}" + except httpx.TimeoutException: + return f"参考图生图请求超时({timeout_val:.0f}s)" + except httpx.RequestError as exc: + return f"参考图生图请求失败: {exc}" - try: - data = response.json() - except Exception: - return f"API 返回错误 (非JSON): {response.text[:100]}" - - generated_image = _parse_generated_image(data) - if generated_image is None: - logger.error(f"参考图生图 API 返回 (未找到图片内容): {data}") - return f"API 返回原文 (错误:未找到图片内容): {data}" - - logger.info(f"参考图生图 API 返回: {data}") - if generated_image.image_url: - logger.info(f"提取图片链接: {generated_image.image_url}") - elif generated_image.image_bytes is not None: - logger.info("提取图片字节: bytes=%s", len(generated_image.image_bytes)) - return generated_image - finally: - for handle in file_handles: - try: - handle.close() - except Exception: - pass + try: + data = response.json() + except Exception: + return f"API 返回错误 (非JSON): {response.text[:100]}" + + generated_image = _parse_generated_image(data) + if generated_image is None: + logger.error(f"参考图生图 API 返回 (未找到图片内容): {data}") + return f"API 返回原文 (错误:未找到图片内容): {data}" + + logger.info(f"参考图生图 API 返回: {data}") + if 
generated_image.image_url: + logger.info(f"提取图片链接: {generated_image.image_url}") + elif generated_image.image_bytes is not None: + logger.info("提取图片字节: bytes=%s", len(generated_image.image_bytes)) + return generated_image async def _download_and_send( diff --git a/tests/test_ai_coordinator_queue_routing.py b/tests/test_ai_coordinator_queue_routing.py index 7b3832ea..16266360 100644 --- a/tests/test_ai_coordinator_queue_routing.py +++ b/tests/test_ai_coordinator_queue_routing.py @@ -152,3 +152,41 @@ async def test_handle_private_reply_includes_trigger_message_id_in_full_question assert await_args is not None request_data = await_args.args[0] assert 'message_id="65432"' in request_data["full_question"] + + +@pytest.mark.asyncio +async def test_handle_private_reply_avoids_extra_blank_line_without_attachments() -> ( + None +): + coordinator: Any = object.__new__(AICoordinator) + queue_manager = SimpleNamespace( + add_superadmin_request=AsyncMock(), + add_private_request=AsyncMock(), + ) + coordinator.config = SimpleNamespace( + superadmin_qq=99999, + chat_model=SimpleNamespace(model_name="chat-model"), + ) + coordinator.security = SimpleNamespace( + detect_injection=AsyncMock(return_value=False) + ) + coordinator.history_manager = SimpleNamespace( + modify_last_private_message=AsyncMock() + ) + coordinator.queue_manager = queue_manager + coordinator.model_pool = SimpleNamespace( + select_chat_config=lambda chat_model, user_id: chat_model + ) + + await AICoordinator.handle_private_reply( + coordinator, + user_id=20001, + text="hello", + message_content=[], + sender_name="member", + ) + + await_args = cast(AsyncMock, queue_manager.add_private_request).await_args + assert await_args is not None + request_data = await_args.args[0] + assert "\n\n " not in request_data["full_question"] diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 6b713454..f970e57c 100644 --- a/tests/test_ai_draw_one_handler.py +++ 
b/tests/test_ai_draw_one_handler.py @@ -515,6 +515,61 @@ async def _fake_request_with_retry( assert seen_request["data"]["background"] == "transparent" assert len(seen_request["files"]) == 1 assert seen_request["files"][0][0] == "image" + assert isinstance(seen_request["files"][0][1][1], bytes) + assert seen_request["files"][0][1][1] == _PNG_BYTES + + +@pytest.mark.asyncio +async def test_call_openai_models_edit_uses_retry_safe_byte_payloads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") + reference_path = tmp_path / "reference.png" + reference_path.write_bytes(_PNG_BYTES) + + class _FakeResponse: + text = "" + + def json(self) -> dict[str, Any]: + return {"data": [{"base64": payload_base64}]} + + async def _fake_request_with_retry( + method: str, + url: str, + **kwargs: Any, + ) -> _FakeResponse: + assert method == "POST" + assert url == "https://edit.example.com/v1/images/edits" + files = kwargs["files"] + assert len(files) == 1 + filename, payload, content_type = files[0][1] + assert filename == "reference.png" + assert payload == _PNG_BYTES + assert isinstance(payload, bytes) + assert content_type == "image/png" + return _FakeResponse() + + monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) + + result = await ai_draw_handler._call_openai_models_edit( + prompt="use this as reference", + api_url="https://edit.example.com", + api_key="sk-edit", + model_name="grok-edit-1.0", + size="1024x1024", + quality="", + style="", + response_format="base64", + n=None, + timeout_val=30.0, + reference_image_paths=[reference_path], + extra_params={}, + context={}, + ) + + assert isinstance(result, ai_draw_handler._GeneratedImagePayload) + assert result.image_bytes == _PNG_BYTES @pytest.mark.asyncio diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py new file mode 100644 index 00000000..520ba617 --- /dev/null +++ 
b/tests/test_runtime_api_chat_stream.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest +from aiohttp import web + +from Undefined.api import RuntimeAPIContext, RuntimeAPIServer +from Undefined.api import app as runtime_api_app + + +class _DummyTransport: + def is_closing(self) -> bool: + return False + + +class _DummyRequest(SimpleNamespace): + async def json(self) -> dict[str, object]: + return {"message": "hello", "stream": True} + + +class _DummyStreamResponse: + def __init__( + self, + *, + status: int, + reason: str, + headers: dict[str, str], + ) -> None: + self.status = status + self.reason = reason + self.headers = dict(headers) + self.writes: list[bytes] = [] + self.eof_written = False + + async def prepare(self, request: web.Request) -> _DummyStreamResponse: + _ = request + return self + + async def write(self, data: bytes) -> None: + self.writes.append(data) + + async def write_eof(self) -> None: + self.eof_written = True + + +@pytest.mark.asyncio +async def test_runtime_chat_stream_renders_each_message_once( + monkeypatch: pytest.MonkeyPatch, +) -> None: + render_calls: list[str] = [] + + async def _fake_render_message_with_pic_placeholders( + message: str, + *, + registry: Any, + scope_key: str, + strict: bool, + ) -> Any: + _ = registry, scope_key, strict + render_calls.append(message) + return SimpleNamespace( + delivery_text="rendered stream reply", + history_text="rendered history reply", + attachments=[], + ) + + context = RuntimeAPIContext( + config_getter=lambda: SimpleNamespace( + api=SimpleNamespace( + enabled=True, + host="127.0.0.1", + port=8788, + auth_key="changeme", + openapi_enabled=True, + ), + superadmin_qq=10001, + bot_qq=20002, + ), + onebot=SimpleNamespace(connection_status=lambda: {}), + ai=SimpleNamespace( + attachment_registry=object(), + memory_storage=SimpleNamespace(count=lambda: 0), + ), + 
command_dispatcher=SimpleNamespace(), + queue_manager=SimpleNamespace(snapshot=lambda: {}), + history_manager=SimpleNamespace(add_private_message=AsyncMock()), + ) + server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) + + async def _fake_run_webui_chat(*, text: str, send_output: Any) -> str: + assert text == "hello" + await send_output(42, "bot reply with ") + return "chat" + + monkeypatch.setattr( + runtime_api_app, + "render_message_with_pic_placeholders", + _fake_render_message_with_pic_placeholders, + ) + monkeypatch.setattr(web, "StreamResponse", _DummyStreamResponse) + monkeypatch.setattr(server, "_run_webui_chat", _fake_run_webui_chat) + + request = cast( + web.Request, + cast( + Any, + _DummyRequest( + transport=_DummyTransport(), + ), + ), + ) + + response = await server._chat_handler(request) + + assert isinstance(response, _DummyStreamResponse) + assert render_calls == ["bot reply with "] + payload = b"".join(response.writes).decode("utf-8") + assert payload.count("event: message") == 1 + assert "rendered stream reply" in payload + assert "event: done" in payload + assert response.eof_written is True From 735677861aaa45e43d9c4635a759f440b2f876a1 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 21:54:51 +0800 Subject: [PATCH 13/21] fix(ai): avoid blocking image edit file io Add async binary reads for reference images and cover related prompt-format regressions. 
Co-authored-by: GPT-5.4 xhigh --- .../tools/ai_draw_one/handler.py | 4 +- src/Undefined/utils/io.py | 14 ++++ tests/test_ai_draw_one_handler.py | 13 ++++ tests/test_runtime_api_chat_stream.py | 66 +++++++++++++++++++ 4 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index 441c671e..d5e07b1b 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -24,6 +24,7 @@ from Undefined.ai.parsing import extract_choices_content from Undefined.skills.http_client import request_with_retry from Undefined.skills.http_config import get_request_timeout, get_xingzhige_url +from Undefined.utils.io import read_bytes from Undefined.utils.resources import read_text_resource logger = logging.getLogger(__name__) @@ -602,12 +603,13 @@ async def _call_openai_models_edit( files: list[tuple[str, tuple[str, bytes, str]]] = [] for path in reference_image_paths: + file_bytes = await read_bytes(path) files.append( ( "image", ( path.name, - path.read_bytes(), + file_bytes, _guess_upload_media_type(path), ), ) diff --git a/src/Undefined/utils/io.py b/src/Undefined/utils/io.py index c4b46bef..738a681a 100644 --- a/src/Undefined/utils/io.py +++ b/src/Undefined/utils/io.py @@ -231,6 +231,14 @@ def _read_text_sync(target: Path, use_lock: bool) -> str | None: return target.read_text(encoding="utf-8") +def _read_bytes_sync(target: Path, use_lock: bool) -> bytes: + if use_lock: + lock_path = target.with_name(f"{target.name}.lock") + with FileLock(lock_path, shared=True): + return target.read_bytes() + return target.read_bytes() + + async def write_text( file_path: str | Path, content: str, use_lock: bool = True ) -> None: @@ -243,3 +251,9 @@ async def read_text(file_path: str | Path, use_lock: bool = False) -> str | None 
"""异步读取文本文件""" target = Path(file_path) return await asyncio.to_thread(_read_text_sync, target, use_lock) + + +async def read_bytes(file_path: str | Path, use_lock: bool = False) -> bytes: + """异步读取二进制文件""" + target = Path(file_path) + return await asyncio.to_thread(_read_bytes_sync, target, use_lock) diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index f970e57c..27b4c991 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -527,6 +527,16 @@ async def test_call_openai_models_edit_uses_retry_safe_byte_payloads( payload_base64 = base64.b64encode(_PNG_BYTES).decode("ascii") reference_path = tmp_path / "reference.png" reference_path.write_bytes(_PNG_BYTES) + observed_read_paths: list[tuple[Path, bool]] = [] + + async def _fake_read_bytes( + file_path: str | Path, use_lock: bool = False + ) -> bytes: + observed_read_paths.append((Path(file_path), use_lock)) + return _PNG_BYTES + + def _unexpected_sync_read_bytes(_self: Path) -> bytes: + raise AssertionError("should use async read_bytes helper instead of Path.read_bytes") class _FakeResponse: text = "" @@ -550,6 +560,8 @@ async def _fake_request_with_retry( assert content_type == "image/png" return _FakeResponse() + monkeypatch.setattr(ai_draw_handler, "read_bytes", _fake_read_bytes) + monkeypatch.setattr(type(reference_path), "read_bytes", _unexpected_sync_read_bytes) monkeypatch.setattr(ai_draw_handler, "request_with_retry", _fake_request_with_retry) result = await ai_draw_handler._call_openai_models_edit( @@ -570,6 +582,7 @@ async def _fake_request_with_retry( assert isinstance(result, ai_draw_handler._GeneratedImagePayload) assert result.image_bytes == _PNG_BYTES + assert observed_read_paths == [(reference_path, False)] @pytest.mark.asyncio diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py index 520ba617..693bd5ea 100644 --- a/tests/test_runtime_api_chat_stream.py +++ 
b/tests/test_runtime_api_chat_stream.py @@ -122,3 +122,69 @@ async def _fake_run_webui_chat(*, text: str, send_output: Any) -> str: assert "rendered stream reply" in payload assert "event: done" in payload assert response.eof_written is True + + +@pytest.mark.asyncio +async def test_run_webui_chat_avoids_extra_blank_line_without_attachments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured_prompt: dict[str, str] = {} + + async def _fake_register_message_attachments(**kwargs: Any) -> Any: + _ = kwargs + return SimpleNamespace(normalized_text="hello", attachments=[]) + + async def _fake_ask(full_question: str, **kwargs: Any) -> str: + _ = kwargs + captured_prompt["full_question"] = full_question + return "" + + context = RuntimeAPIContext( + config_getter=lambda: SimpleNamespace( + api=SimpleNamespace( + enabled=True, + host="127.0.0.1", + port=8788, + auth_key="changeme", + openapi_enabled=True, + ), + superadmin_qq=10001, + bot_qq=20002, + ), + onebot=SimpleNamespace( + connection_status=lambda: {}, + get_image=AsyncMock(), + get_forward_msg=AsyncMock(), + ), + ai=SimpleNamespace( + attachment_registry=object(), + ask=_fake_ask, + memory_storage=SimpleNamespace(count=lambda: 0), + runtime_config=SimpleNamespace(), + ), + command_dispatcher=SimpleNamespace( + parse_command=lambda _text: None, + dispatch_private=AsyncMock(), + ), + queue_manager=SimpleNamespace(snapshot=lambda: {}), + history_manager=SimpleNamespace(add_private_message=AsyncMock()), + ) + server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) + + monkeypatch.setattr( + runtime_api_app, + "register_message_attachments", + _fake_register_message_attachments, + ) + monkeypatch.setattr(runtime_api_app, "collect_context_resources", lambda _vars: {}) + + sent_messages: list[tuple[int, str]] = [] + + async def _send_output(user_id: int, message: str) -> None: + sent_messages.append((user_id, message)) + + result = await server._run_webui_chat(text="hello", send_output=_send_output) + + 
assert result == "chat" + assert sent_messages == [] + assert "\n\n " not in captured_prompt["full_question"] From ee03214e94a8c4db9540b303310d301e0a17ba6f Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 21:56:25 +0800 Subject: [PATCH 14/21] style(tests): format ai draw one handler test Apply Ruff formatting to the new async read regression test. Co-authored-by: GPT-5.4 xhigh --- tests/test_ai_draw_one_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_ai_draw_one_handler.py b/tests/test_ai_draw_one_handler.py index 27b4c991..49f76634 100644 --- a/tests/test_ai_draw_one_handler.py +++ b/tests/test_ai_draw_one_handler.py @@ -529,14 +529,14 @@ async def test_call_openai_models_edit_uses_retry_safe_byte_payloads( reference_path.write_bytes(_PNG_BYTES) observed_read_paths: list[tuple[Path, bool]] = [] - async def _fake_read_bytes( - file_path: str | Path, use_lock: bool = False - ) -> bytes: + async def _fake_read_bytes(file_path: str | Path, use_lock: bool = False) -> bytes: observed_read_paths.append((Path(file_path), use_lock)) return _PNG_BYTES def _unexpected_sync_read_bytes(_self: Path) -> bytes: - raise AssertionError("should use async read_bytes helper instead of Path.read_bytes") + raise AssertionError( + "should use async read_bytes helper instead of Path.read_bytes" + ) class _FakeResponse: text = "" From 6506b53f84389cd29f8769c36412c14199421336 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 22:22:12 +0800 Subject: [PATCH 15/21] chore(version): bump version to 3.2.8 Update package versions, lockfiles, and changelog for the 3.2.8 release. 
Co-authored-by: GPT-5.4 xhigh --- CHANGELOG.md | 19 +++++++++++++++++++ apps/undefined-console/package-lock.json | 4 ++-- apps/undefined-console/package.json | 2 +- apps/undefined-console/src-tauri/Cargo.lock | 2 +- apps/undefined-console/src-tauri/Cargo.toml | 2 +- .../src-tauri/tauri.conf.json | 2 +- pyproject.toml | 2 +- src/Undefined/__init__.py | 2 +- uv.lock | 2 +- 9 files changed, 28 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 177724e1..7986fb6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,22 @@ +## v3.2.8 多模态附件与参考图生图 + +围绕多模态消息链路进行了较大增强,引入统一的附件 UID 注册与 `` 图文混排机制,打通群聊、私聊、WebUI 和合并转发中的图片/文件上下文。同步为生图工具补齐参考图生图、独立图片编辑模型配置与提示词审核能力,并完善 Runtime API、认知侧写和动态技能加载的稳定性。 + +- 新增统一附件注册系统,群聊、私聊、WebUI 会话和合并转发中的图片/文件会登记为内部附件 UID,并写入历史记录与提示词上下文。 +- 系统提示词与发送链路支持 `` 图文混排,AI 可直接在回复中内嵌当前会话内的图片 UID。 +- 新增 `fetch_image_uid` 工具,支持将远程图片 URL 拉取并注册为当前会话可复用的图片 UID。 +- `get_picture`、`render_html`、`render_markdown`、`render_latex` 默认改为 `embed` 返回可嵌入图片 UID,同时保留 `send` 直接发送模式。 +- `file_analysis_agent` 及相关文件/图片链路已适配附件 UID,优先使用内部 `pic_*` / `file_*` 标识,也继续兼容 URL 与 legacy `file_id`。 +- `ai_draw_one` 支持基于 `reference_image_uids` 的参考图生图,新增 `[models.image_edit]` 配置节,并接入 OpenAI 兼容的 `/v1/images/edits` 接口。 +- 新增基于 Agent 模型的生图提示词审核能力,并补充独立审核提示模板。 +- 优化图片生成请求体验,支持 base64 返回、保留显式尺寸、锁定模型参数,并更清晰地暴露上游错误。 +- 修复附件渲染、图片编辑上传与阻塞式文件读取问题,增强多模态链路稳定性。 +- Runtime API 与 WebUI 现在会注册聊天附件、渲染图片 UID,并支持本地 `file://` 图片预览;同时取消 `_agent` 工具调用的固定超时上限。 +- 认知服务会在用户昵称或群名变化时自动刷新 profile 展示名与向量索引,减少侧写名称陈旧问题。 +- 修复动态技能模块加载与超时包装问题,并补充相关测试覆盖;同步整理 README,移除顶部头图。 + +--- + ## v3.2.7 arXiv 工具集与运行时变更感知 新增 arXiv 论文搜索与提取工具集,以及运行时 CHANGELOG 查询能力。重构了生图工具支持 OpenAI 兼容接口,引入 grok_search 联网搜索工具,并让 AI 在系统提示词中感知自身模型配置信息。同步修复了多项稳定性问题与 CI 效率优化。 diff --git a/apps/undefined-console/package-lock.json b/apps/undefined-console/package-lock.json index 6001addb..6fd8dd29 100644 --- a/apps/undefined-console/package-lock.json +++ b/apps/undefined-console/package-lock.json @@ -1,12 +1,12 @@ { "name": "undefined-console", 
- "version": "3.2.7", + "version": "3.2.8", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "undefined-console", - "version": "3.2.7", + "version": "3.2.8", "dependencies": { "@tauri-apps/api": "^2.3.0", "@tauri-apps/plugin-http": "^2.3.0" diff --git a/apps/undefined-console/package.json b/apps/undefined-console/package.json index 6060cbe3..d99ef216 100644 --- a/apps/undefined-console/package.json +++ b/apps/undefined-console/package.json @@ -1,7 +1,7 @@ { "name": "undefined-console", "private": true, - "version": "3.2.7", + "version": "3.2.8", "type": "module", "scripts": { "tauri": "tauri", diff --git a/apps/undefined-console/src-tauri/Cargo.lock b/apps/undefined-console/src-tauri/Cargo.lock index d16730ab..5a512d97 100644 --- a/apps/undefined-console/src-tauri/Cargo.lock +++ b/apps/undefined-console/src-tauri/Cargo.lock @@ -4063,7 +4063,7 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "undefined_console" -version = "3.2.7" +version = "3.2.8" dependencies = [ "serde", "serde_json", diff --git a/apps/undefined-console/src-tauri/Cargo.toml b/apps/undefined-console/src-tauri/Cargo.toml index 2df37efc..0b91ad48 100644 --- a/apps/undefined-console/src-tauri/Cargo.toml +++ b/apps/undefined-console/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "undefined_console" -version = "3.2.7" +version = "3.2.8" description = "Undefined cross-platform management console" authors = ["Undefined contributors"] license = "MIT" diff --git a/apps/undefined-console/src-tauri/tauri.conf.json b/apps/undefined-console/src-tauri/tauri.conf.json index f29fd6dc..cc743feb 100644 --- a/apps/undefined-console/src-tauri/tauri.conf.json +++ b/apps/undefined-console/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "Undefined Console", - "version": "3.2.7", + "version": "3.2.8", "identifier": "com.undefined.console", "build": { "beforeDevCommand": "npm run dev", 
diff --git a/pyproject.toml b/pyproject.toml index ebf60f9e..45fb893c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Undefined-bot" -version = "3.2.7" +version = "3.2.8" description = "QQ bot platform with cognitive memory architecture and multi-agent Skills, via OneBot V11." readme = "README.md" authors = [ diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 6c8d6db7..b90dcf3e 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -1,3 +1,3 @@ """Undefined - A high-performance, highly scalable QQ group and private chat robot based on a self-developed architecture.""" -__version__ = "3.2.7" +__version__ = "3.2.8" diff --git a/uv.lock b/uv.lock index 0955f37b..407fddf9 100644 --- a/uv.lock +++ b/uv.lock @@ -4638,7 +4638,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.2.7" +version = "3.2.8" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From 0ff699bf159100a68b483deea8cad885d30c5ff9 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 22:52:35 +0800 Subject: [PATCH 16/21] fix(runtime): harden profile sync and attachment loading Avoid syncing placeholder poke names into cognitive profiles, align WebUI tool timeout handling with actual agent schemas, and load attachment registry asynchronously. 
Co-authored-by: GPT-5.4 xhigh --- src/Undefined/attachments.py | 55 +++++++++++++++---- src/Undefined/handlers.py | 19 ++++--- src/Undefined/main.py | 1 + src/Undefined/webui/routes/_runtime.py | 39 +++++++++++++- tests/test_attachments.py | 46 ++++++++++++++++ tests/test_handlers_poke_history.py | 73 +++++++++++++++++++++++++- tests/test_queue_timeout_budgets.py | 15 +++++- 7 files changed, 226 insertions(+), 22 deletions(-) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index 00d5e7bd..4cbfd93d 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -385,18 +385,13 @@ def __init__( self._http_client = http_client self._lock = asyncio.Lock() self._records: dict[str, AttachmentRecord] = {} + self._loaded = False + self._load_task: asyncio.Task[None] | None = None self._load_from_disk() - def _load_from_disk(self) -> None: - if not self._registry_path.exists(): - return - try: - raw = json.loads(self._registry_path.read_text(encoding="utf-8")) - except Exception as exc: - logger.warning("[AttachmentRegistry] 读取失败: %s", exc) - return + def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: if not isinstance(raw, dict): - return + return {} loaded: dict[str, AttachmentRecord] = {} for uid, item in raw.items(): if not isinstance(item, dict): @@ -421,7 +416,46 @@ def _load_from_disk(self) -> None: ) except Exception: continue - self._records = loaded + return loaded + + def _load_from_disk(self) -> None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + self._load_from_disk_sync() + return + self._load_task = loop.create_task(self._load_from_disk_async()) + + def _load_from_disk_sync(self) -> None: + if not self._registry_path.exists(): + self._loaded = True + return + try: + raw = json.loads(self._registry_path.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("[AttachmentRegistry] 读取失败: %s", exc) + self._loaded = True + return + self._records = 
self._load_records_from_payload(raw) + self._loaded = True + + async def _load_from_disk_async(self) -> None: + try: + raw = await io.read_json(self._registry_path, use_lock=False) + except Exception as exc: + logger.warning("[AttachmentRegistry] 读取失败: %s", exc) + self._loaded = True + return + self._records = self._load_records_from_payload(raw) + self._loaded = True + + async def load(self) -> None: + """等待注册表完成初始加载。""" + if self._loaded: + return + if self._load_task is None: + self._load_task = asyncio.create_task(self._load_from_disk_async()) + await self._load_task async def _persist(self) -> None: payload = {uid: asdict(record) for uid, record in self._records.items()} @@ -464,6 +498,7 @@ async def register_bytes( source_ref: str = "", mime_type: str | None = None, ) -> AttachmentRecord: + await self.load() normalized_kind = _media_kind_from_value(kind) normalized_media_type = ( "image" if normalized_kind == "image" else normalized_kind diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index e5875929..7d0a73a8 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -648,15 +648,16 @@ async def _record_private_poke_history( exc, ) - display_name = sender_nickname or user_name or f"QQ{user_id}" + resolved_sender_name = (sender_nickname or user_name).strip() + display_name = resolved_sender_name or f"QQ{user_id}" normalized_user_name = user_name or display_name poke_text = _format_poke_history_text(display_name, user_id) - if display_name.strip() and self._can_refresh_profile_display_names(): + if resolved_sender_name and self._can_refresh_profile_display_names(): self._spawn_background_task( f"profile_name_refresh_private_poke:{user_id}", self._refresh_profile_display_names( sender_id=user_id, - sender_name=display_name, + sender_name=resolved_sender_name, ), ) @@ -725,19 +726,21 @@ async def _record_group_poke_history( exc, ) - display_name = sender_card or sender_nickname or f"QQ{sender_id}" + resolved_sender_name = 
(sender_card or sender_nickname).strip() + resolved_group_name = group_name.strip() + display_name = resolved_sender_name or f"QQ{sender_id}" poke_text = _format_poke_history_text(display_name, sender_id) - normalized_group_name = group_name or f"群{group_id}" - if (display_name.strip() or normalized_group_name.strip()) and ( + normalized_group_name = resolved_group_name or f"群{group_id}" + if (resolved_sender_name or resolved_group_name) and ( self._can_refresh_profile_display_names() ): self._spawn_background_task( f"profile_name_refresh_group_poke:{group_id}:{sender_id}", self._refresh_profile_display_names( sender_id=sender_id, - sender_name=display_name, + sender_name=resolved_sender_name, group_id=group_id, - group_name=normalized_group_name, + group_name=resolved_group_name, ), ) diff --git a/src/Undefined/main.py b/src/Undefined/main.py index ce21cd94..6980e67c 100644 --- a/src/Undefined/main.py +++ b/src/Undefined/main.py @@ -207,6 +207,7 @@ async def main() -> None: bot_qq=config.bot_qq, runtime_config=config, ) + await ai.attachment_registry.load() faq_storage = FAQStorage() from Undefined.knowledge import RetrievalRuntime diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index 2d2043a4..fa740566 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -36,9 +36,46 @@ def _chat_proxy_timeout_seconds() -> float: return compute_queued_llm_timeout_seconds(cfg, cfg.chat_model) +def _load_local_function_names(*roots: Path) -> set[str]: + names: set[str] = set() + for root in roots: + if not root.exists(): + continue + for config_path in root.rglob("config.json"): + try: + relative_parts = config_path.relative_to(root).parts + except ValueError: + relative_parts = config_path.parts + if any(part.startswith("_") for part in relative_parts): + continue + try: + raw = json.loads(config_path.read_text(encoding="utf-8")) + except Exception: + continue + function = 
raw.get("function", {}) + if not isinstance(function, dict): + continue + name = str(function.get("name", "") or "").strip() + if name: + names.add(name) + return names + + +def _get_local_agent_tool_names() -> set[str]: + skills_root = Path(__file__).resolve().parents[2] / "skills" + return _load_local_function_names(skills_root / "agents") + + +def _get_local_tool_names() -> set[str]: + skills_root = Path(__file__).resolve().parents[2] / "skills" + return _load_local_function_names(skills_root / "tools", skills_root / "toolsets") + + def _tool_invoke_proxy_timeout_seconds(tool_name: str) -> float | None: normalized_name = str(tool_name or "").strip() - if normalized_name.endswith("_agent"): + if normalized_name in _get_local_agent_tool_names(): + return None + if normalized_name not in _get_local_tool_names(): return None cfg = get_config(strict=False) diff --git a/tests/test_attachments.py b/tests/test_attachments.py index 1c7874b1..d2377d15 100644 --- a/tests/test_attachments.py +++ b/tests/test_attachments.py @@ -2,6 +2,7 @@ import base64 from pathlib import Path +from typing import Any import pytest @@ -10,6 +11,7 @@ register_message_attachments, render_message_with_pic_placeholders, ) +from Undefined.utils import io as io_utils _PNG_BYTES = ( @@ -40,10 +42,54 @@ async def test_attachment_registry_persists_and_respects_scope( ) reloaded = AttachmentRegistry(registry_path=registry_path, cache_dir=cache_dir) + await reloaded.load() assert reloaded.resolve(record.uid, "group:10001") is not None assert reloaded.resolve(record.uid, "group:10002") is None +@pytest.mark.asyncio +async def test_attachment_registry_load_uses_async_read_json( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + registry_path = tmp_path / "attachment_registry.json" + cache_dir = tmp_path / "attachments" + seen_calls: list[tuple[Path, bool]] = [] + payload = { + "pic_async123": { + "uid": "pic_async123", + "scope_key": "group:10001", + "kind": "image", + "media_type": 
"image", + "display_name": "cat.png", + "source_kind": "test", + "source_ref": "test", + "local_path": str(cache_dir / "pic_async123.png"), + "mime_type": "image/png", + "sha256": "digest", + "created_at": "2026-04-02T00:00:00", + } + } + + async def _fake_read_json(file_path: str | Path, use_lock: bool = False) -> Any: + seen_calls.append((Path(file_path), use_lock)) + return payload + + def _unexpected_sync_read_text(_self: Path, *_args: Any, **_kwargs: Any) -> str: + raise AssertionError( + "should use async read_json helper instead of Path.read_text" + ) + + monkeypatch.setattr(io_utils, "read_json", _fake_read_json) + monkeypatch.setattr(Path, "read_text", _unexpected_sync_read_text) + + registry = AttachmentRegistry(registry_path=registry_path, cache_dir=cache_dir) + await registry.load() + + assert seen_calls == [(registry_path, False)] + assert registry.resolve("pic_async123", "group:10001") is not None + + @pytest.mark.asyncio async def test_register_message_attachments_normalizes_webui_base64_image( tmp_path: Path, diff --git a/tests/test_handlers_poke_history.py b/tests/test_handlers_poke_history.py index d374a8e4..8dc81b1f 100644 --- a/tests/test_handlers_poke_history.py +++ b/tests/test_handlers_poke_history.py @@ -1,5 +1,4 @@ """MessageHandler 拍一拍历史记录测试""" - from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock @@ -26,10 +25,15 @@ def _build_handler() -> Any: handle_private_reply=AsyncMock(), handle_auto_reply=AsyncMock(), ) + handler.ai = SimpleNamespace(_cognitive_service=None) handler.onebot = SimpleNamespace( get_stranger_info=AsyncMock(return_value={"nickname": "测试用户"}), + get_group_member_info=AsyncMock( + return_value={"card": "群名片", "nickname": "群昵称"} + ), get_group_info=AsyncMock(return_value={"group_name": "测试群"}), ) + handler._background_tasks = set() return handler @@ -117,3 +121,70 @@ async def test_group_poke_writes_history_and_triggers_reply() -> None: 
handler.history_manager.add_private_message.assert_not_called() handler.ai_coordinator.handle_private_reply.assert_not_called() + + +@pytest.mark.asyncio +async def test_private_poke_skips_profile_refresh_for_placeholder_name() -> None: + handler = _build_handler() + handler.ai = SimpleNamespace(_cognitive_service=SimpleNamespace(enabled=True)) + handler.onebot = SimpleNamespace(get_stranger_info=AsyncMock(return_value={})) + handler._refresh_profile_display_names = AsyncMock() + scheduled: list[tuple[str, Any]] = [] + + def _fake_spawn(name: str, coroutine: Any) -> None: + scheduled.append((name, coroutine)) + + handler._spawn_background_task = _fake_spawn + event = { + "post_type": "notice", + "notice_type": "poke", + "target_id": 10000, + "group_id": 0, + "user_id": 20001, + "sender": {"user_id": 20001}, + } + + await handler.handle_message(event) + + assert scheduled == [] + handler._refresh_profile_display_names.assert_not_awaited() + private_history_call = handler.history_manager.add_private_message.call_args + assert private_history_call is not None + assert private_history_call.kwargs["display_name"] == "QQ20001" + + +@pytest.mark.asyncio +async def test_group_poke_skips_profile_refresh_for_placeholder_names() -> None: + handler = _build_handler() + handler.ai = SimpleNamespace(_cognitive_service=SimpleNamespace(enabled=True)) + handler.onebot = SimpleNamespace( + get_group_member_info=AsyncMock(return_value={}), + get_group_info=AsyncMock(return_value={}), + ) + handler._refresh_profile_display_names = AsyncMock() + scheduled: list[tuple[str, Any]] = [] + + def _fake_spawn(name: str, coroutine: Any) -> None: + scheduled.append((name, coroutine)) + + handler._spawn_background_task = _fake_spawn + event = { + "post_type": "notice", + "notice_type": "poke", + "target_id": 10000, + "group_id": 30001, + "user_id": 20001, + "sender": {"user_id": 20001}, + } + + await handler.handle_message(event) + + assert scheduled == [] + 
handler._refresh_profile_display_names.assert_not_awaited() + group_history_call = handler.history_manager.add_group_message.call_args + assert group_history_call is not None + assert group_history_call.kwargs["group_name"] == "群30001" + assert ( + group_history_call.kwargs["text_content"] + == "QQ20001(暱称)[20001(QQ号)] 拍了拍你。" + ) diff --git a/tests/test_queue_timeout_budgets.py b/tests/test_queue_timeout_budgets.py index 32b16091..245264ab 100644 --- a/tests/test_queue_timeout_budgets.py +++ b/tests/test_queue_timeout_budgets.py @@ -134,16 +134,27 @@ def test_chat_proxy_timeout_uses_queue_budget(monkeypatch: pytest.MonkeyPatch) - ) -def test_tool_invoke_proxy_timeout_skips_agents( +def test_tool_invoke_proxy_timeout_uses_local_schema_sets( monkeypatch: pytest.MonkeyPatch, ) -> None: cfg = SimpleNamespace(api=SimpleNamespace(tool_invoke_timeout=120)) monkeypatch.setattr( runtime_routes, "get_config", lambda strict=False: cast(Any, cfg) ) + monkeypatch.setattr( + runtime_routes, + "_get_local_agent_tool_names", + lambda: {"custom_agent_runner"}, + ) + monkeypatch.setattr( + runtime_routes, + "_get_local_tool_names", + lambda: {"messages.send_message"}, + ) - assert runtime_routes._tool_invoke_proxy_timeout_seconds("web_agent") is None + assert runtime_routes._tool_invoke_proxy_timeout_seconds("custom_agent_runner") is None assert ( runtime_routes._tool_invoke_proxy_timeout_seconds("messages.send_message") == 180.0 ) + assert runtime_routes._tool_invoke_proxy_timeout_seconds("unknown_tool") is None From 8bb5ea35295fb82838ce52dce11bbc5b79d93d5e Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Thu, 2 Apr 2026 22:54:02 +0800 Subject: [PATCH 17/21] style(tests): format runtime review regression tests Apply Ruff formatting to the new review regression tests. 
Co-authored-by: GPT-5.4 xhigh --- tests/test_handlers_poke_history.py | 1 + tests/test_queue_timeout_budgets.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_handlers_poke_history.py b/tests/test_handlers_poke_history.py index 8dc81b1f..ede69a23 100644 --- a/tests/test_handlers_poke_history.py +++ b/tests/test_handlers_poke_history.py @@ -1,4 +1,5 @@ """MessageHandler 拍一拍历史记录测试""" + from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock diff --git a/tests/test_queue_timeout_budgets.py b/tests/test_queue_timeout_budgets.py index 245264ab..44a94807 100644 --- a/tests/test_queue_timeout_budgets.py +++ b/tests/test_queue_timeout_budgets.py @@ -152,7 +152,9 @@ def test_tool_invoke_proxy_timeout_uses_local_schema_sets( lambda: {"messages.send_message"}, ) - assert runtime_routes._tool_invoke_proxy_timeout_seconds("custom_agent_runner") is None + assert ( + runtime_routes._tool_invoke_proxy_timeout_seconds("custom_agent_runner") is None + ) assert ( runtime_routes._tool_invoke_proxy_timeout_seconds("messages.send_message") == 180.0 From fb562328ce23eccef2e6f122cca18f4e7b913e0a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Fri, 3 Apr 2026 09:13:45 +0800 Subject: [PATCH 18/21] fix(api): preserve webui attachment scope Make WebUI session scope explicit in runtime chat context and add regression tests for WebUI-scoped image embeds. 
Co-authored-by: GPT-5.4 xhigh --- src/Undefined/api/app.py | 2 ++ tests/test_runtime_api_chat_stream.py | 4 ++- tests/test_send_message_tool.py | 45 +++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index f1cad772..e8656004 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1299,6 +1299,8 @@ def send_message_callback( if value is not None: ctx.set_resource(key, value) ctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) + ctx.set_resource("webui_session", True) + ctx.set_resource("webui_permission", "superadmin") result = await self._ctx.ai.ask( full_question, diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py index 693bd5ea..c1f6e995 100644 --- a/tests/test_runtime_api_chat_stream.py +++ b/tests/test_runtime_api_chat_stream.py @@ -129,13 +129,14 @@ async def test_run_webui_chat_avoids_extra_blank_line_without_attachments( monkeypatch: pytest.MonkeyPatch, ) -> None: captured_prompt: dict[str, str] = {} + captured_extra_context: dict[str, Any] = {} async def _fake_register_message_attachments(**kwargs: Any) -> Any: _ = kwargs return SimpleNamespace(normalized_text="hello", attachments=[]) async def _fake_ask(full_question: str, **kwargs: Any) -> str: - _ = kwargs + captured_extra_context.update(dict(kwargs.get("extra_context") or {})) captured_prompt["full_question"] = full_question return "" @@ -188,3 +189,4 @@ async def _send_output(user_id: int, message: str) -> None: assert result == "chat" assert sent_messages == [] assert "\n\n " not in captured_prompt["full_question"] + assert captured_extra_context["webui_session"] is True diff --git a/tests/test_send_message_tool.py b/tests/test_send_message_tool.py index 12305b62..f94f9e62 100644 --- a/tests/test_send_message_tool.py +++ b/tests/test_send_message_tool.py @@ -206,3 +206,48 @@ async def test_send_message_renders_pic_uid_before_sending(tmp_path: Path) -> No 
assert sent_args.kwargs["history_message"] == ( f"图文并茂\n[图片 uid={record.uid} name=demo.png]\n结束" ) + + +@pytest.mark.asyncio +async def test_send_message_renders_webui_scoped_pic_uid_before_sending( + tmp_path: Path, +) -> None: + sender = SimpleNamespace( + send_group_message=AsyncMock(), + send_private_message=AsyncMock(return_value=88888), + ) + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + record = await registry.register_bytes( + "webui", + b"\x89PNG\r\n\x1a\n", + kind="image", + display_name="webui.png", + source_kind="test", + ) + context: dict[str, Any] = { + "request_type": "private", + "user_id": 42, + "sender_id": 10001, + "request_id": "req-webui-1", + "runtime_config": _build_runtime_config(), + "sender": sender, + "attachment_registry": registry, + "webui_session": True, + } + + result = await execute( + { + "message": f'WebUI 图片\n\n结束', + }, + context, + ) + + assert result == "消息已发送(message_id=88888)" + sent_args = sender.send_private_message.await_args + assert "[CQ:image,file=file://" in sent_args.args[1] + assert sent_args.kwargs["history_message"] == ( + f"WebUI 图片\n[图片 uid={record.uid} name=webui.png]\n结束" + ) From 88343f7c254b6383ac8a5c638524166603559773 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Fri, 3 Apr 2026 09:23:48 +0800 Subject: [PATCH 19/21] perf(prompt): move stable rules ahead of dynamic context Reorder prompt construction so stable instruction blocks stay ahead of frequently changing context, improving prompt-cache friendliness without changing prompt content. 
Co-authored-by: GPT-5.4 xhigh --- src/Undefined/ai/prompts.py | 18 +-- tests/test_prompt_builder_message_order.py | 180 +++++++++++++++++++++ 2 files changed, 189 insertions(+), 9 deletions(-) create mode 100644 tests/test_prompt_builder_message_order.py diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index d9fc16d1..a375d0bc 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -327,6 +327,15 @@ async def build_messages( len(self._anthropic_skill_registry.get_all_skills()), ) + each_rules = await self._load_each_rules() + if each_rules: + messages.append( + { + "role": "system", + "content": f"【强制规则 - 必须在进行任何操作前仔细阅读并严格遵守】\n{each_rules}", + } + ) + if self._memory_storage: memories = self._memory_storage.get_all() if memories: @@ -520,15 +529,6 @@ async def build_messages( } ) - each_rules = await self._load_each_rules() - if each_rules: - messages.append( - { - "role": "system", - "content": f"【强制规则 - 必须在进行任何操作前仔细阅读并严格遵守】\n{each_rules}", - } - ) - messages.append({"role": "user", "content": f"【当前消息】\n{question}"}) logger.debug( "[Prompt] messages_ready=%s question_len=%s", diff --git a/tests/test_prompt_builder_message_order.py b/tests/test_prompt_builder_message_order.py new file mode 100644 index 00000000..338f9257 --- /dev/null +++ b/tests/test_prompt_builder_message_order.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord +from Undefined.memory import Memory + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [ + { + "summary": "刚刚帮用户定位完问题", + "timestamp": "2026-04-03 10:00:00", + } + ] + + +class _FakeCognitiveService: + enabled = True + + async def build_context(self, **kwargs: Any) -> str: + _ = kwargs + return "【认知记忆上下文】\n用户最近在排查缓存命中问题。" 
+ + +@dataclass +class _FakeAnthropicSkill: + name: str + + +class _FakeAnthropicSkillRegistry: + def has_skills(self) -> bool: + return True + + def build_metadata_xml(self) -> str: + return '' + + def get_all_skills(self) -> list[_FakeAnthropicSkill]: + return [_FakeAnthropicSkill(name="demo_skill")] + + +class _FakeMemoryStorage: + def get_all(self) -> list[Memory]: + return [ + Memory( + uuid="mem-1", + fact="用户喜欢详细解释", + created_at="2026-04-03 09:00:00", + ) + ] + + +def _make_builder() -> PromptBuilder: + runtime_config = SimpleNamespace( + keyword_reply_enabled=True, + chat_model=SimpleNamespace( + model_name="gpt-5.4", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=True, + ), + vision_model=SimpleNamespace(model_name="gpt-4.1-mini"), + agent_model=SimpleNamespace(model_name="gpt-5.4-mini"), + embedding_model=SimpleNamespace(model_name="text-embedding-3-small"), + security_model=SimpleNamespace(model_name="gpt-4.1-mini"), + grok_model=SimpleNamespace(model_name="grok-4-search"), + cognitive=SimpleNamespace(enabled=True, recent_end_summaries_inject_k=1), + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=cast(Any, _FakeMemoryStorage()), + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, _FakeAnthropicSkillRegistry()), + cognitive_service=cast(Any, _FakeCognitiveService()), + ) + + +@pytest.mark.asyncio +async def test_build_messages_places_each_rules_before_dynamic_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "每次都要先检查缓存" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + 
) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "研发群", + "timestamp": "2026-04-03 10:01:00", + "message": "上一条消息", + "attachments": [], + "role": "member", + "title": "", + } + ] + + messages = await builder.build_messages( + '\n这次缓存为什么没命中?\n', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "研发群", + "request_type": "group", + }, + ) + + labels = { + "skills": "【可用的 Anthropic Skills】", + "rules": "【强制规则 - 必须在进行任何操作前仔细阅读并严格遵守】", + "memory": "【memory.* 手动长期记忆(可编辑)】", + "cognitive": "【认知记忆上下文】", + "summary": "【短期行动记录(最近 1 条,带时间)】", + "history": "【历史消息存档】", + "time": "【当前时间】", + "current": "【当前消息】", + } + positions = { + name: next( + idx + for idx, message in enumerate(messages) + if marker in str(message.get("content", "")) + ) + for name, marker in labels.items() + } + + assert positions["skills"] < positions["rules"] < positions["memory"] + assert positions["memory"] < positions["cognitive"] < positions["summary"] + assert positions["summary"] < positions["history"] < positions["time"] + assert positions["time"] < positions["current"] + + +@pytest.mark.asyncio +async def test_build_messages_keeps_current_message_as_last_item( + monkeypatch: pytest.MonkeyPatch, +) -> None: + builder = PromptBuilder( + bot_qq=0, + memory_storage=None, + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + ) + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "固定规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + messages = await builder.build_messages("直接提问:缓存是否命中?") + + assert messages[-1] == { + "role": "user", + "content": "【当前消息】\n直接提问:缓存是否命中?", + } From 
651eda212b3eb62121e557fbc6f83939c0cc25aa Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Fri, 3 Apr 2026 10:24:36 +0800 Subject: [PATCH 20/21] fix(runtime): tighten timeout and attachment registry handling Restore bounded WebUI tool invoke timeouts for non-agent tools, prune and flush attachment registry state, deduplicate profile name refresh scheduling, pass history_message in group auto replies, and simplify ai_draw_one request param flow. Co-authored-by: GPT-5.4 xhigh --- src/Undefined/ai/client.py | 7 + src/Undefined/attachments.py | 107 +++++++++++++++- src/Undefined/handlers.py | 121 ++++++++++++------ src/Undefined/services/ai_coordinator.py | 5 +- .../tools/ai_draw_one/handler.py | 11 +- src/Undefined/webui/routes/_runtime.py | 55 ++++---- tests/test_ai_coordinator_queue_routing.py | 49 +++++++ tests/test_attachments.py | 57 ++++++++- tests/test_handlers_poke_history.py | 37 ++++++ tests/test_queue_timeout_budgets.py | 26 +++- 10 files changed, 387 insertions(+), 88 deletions(-) diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index 7c6beb3a..5700f2b5 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -327,6 +327,13 @@ async def close(self) -> None: if hasattr(self, "anthropic_skill_registry"): await self.anthropic_skill_registry.stop_hot_reload() + attachment_registry = getattr(self, "attachment_registry", None) + if attachment_registry is not None and hasattr(attachment_registry, "flush"): + try: + await attachment_registry.flush() + except Exception as exc: + logger.warning("[清理] 刷新附件注册表失败: %s", exc) + # 3) 最后关闭共享 HTTP client if hasattr(self, "_http_client"): logger.info("[清理] 正在关闭 AIClient HTTP 客户端...") diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index 4cbfd93d..2d8e2603 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -13,6 +13,7 @@ import mimetypes from pathlib import Path import re +import time from typing import Any, Awaitable, 
Callable, Mapping, Sequence from urllib.parse import unquote, urlsplit @@ -59,6 +60,8 @@ (b"BM", ".bmp"), ) _FORWARD_ATTACHMENT_MAX_DEPTH = 3 +_ATTACHMENT_CACHE_MAX_AGE_SECONDS = 7 * 24 * 60 * 60 +_ATTACHMENT_REGISTRY_MAX_RECORDS = 2000 @dataclass(frozen=True) @@ -141,8 +144,6 @@ def build_attachment_scope( request_type_text = str(request_type or "").strip().lower() if request_type_text == "private" and user is not None: return f"private:{user}" - if request_type_text == "group" and group is not None: - return f"group:{group}" if user is not None: return f"private:{user}" return None @@ -379,16 +380,105 @@ def __init__( registry_path: Path = ATTACHMENT_REGISTRY_FILE, cache_dir: Path = ATTACHMENT_CACHE_DIR, http_client: httpx.AsyncClient | None = None, + max_records: int = _ATTACHMENT_REGISTRY_MAX_RECORDS, + max_age_seconds: int = _ATTACHMENT_CACHE_MAX_AGE_SECONDS, ) -> None: self._registry_path = registry_path self._cache_dir = cache_dir self._http_client = http_client + self._max_records = max(0, int(max_records)) + self._max_age_seconds = max(0, int(max_age_seconds)) self._lock = asyncio.Lock() self._records: dict[str, AttachmentRecord] = {} self._loaded = False self._load_task: asyncio.Task[None] | None = None self._load_from_disk() + def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: + text = str(raw_path or "").strip() + if not text: + return None + try: + path = Path(text).expanduser().resolve() + cache_root = self._cache_dir.resolve() + except Exception: + return None + if path == cache_root or cache_root not in path.parents: + return None + return path + + def _prune_records(self) -> bool: + dirty = False + now = time.time() + retained: list[tuple[str, AttachmentRecord, Path | None, float]] = [] + removable_paths: set[Path] = set() + + for uid, record in self._records.items(): + cache_path = self._resolve_managed_cache_path(record.local_path) + if cache_path is None or not cache_path.is_file(): + dirty = True + continue + try: + 
mtime = float(cache_path.stat().st_mtime) + except OSError: + dirty = True + removable_paths.add(cache_path) + continue + if self._max_age_seconds > 0 and now - mtime > self._max_age_seconds: + dirty = True + removable_paths.add(cache_path) + continue + retained.append((uid, record, cache_path, mtime)) + + if self._max_records > 0 and len(retained) > self._max_records: + retained.sort(key=lambda item: item[3]) + overflow = len(retained) - self._max_records + for _uid, _record, cache_path, _mtime in retained[:overflow]: + if cache_path is not None: + removable_paths.add(cache_path) + retained = retained[overflow:] + dirty = True + + retained_records = {uid: record for uid, record, _path, _mtime in retained} + retained_paths = { + path.resolve() + for _uid, _record, path, _mtime in retained + if path is not None and path.exists() + } + + for path in removable_paths: + try: + resolved = path.resolve() + except Exception: + resolved = path + if resolved in retained_paths: + continue + try: + path.unlink(missing_ok=True) + dirty = True + except OSError: + continue + + if self._cache_dir.exists(): + for item in self._cache_dir.iterdir(): + if not item.is_file(): + continue + try: + resolved = item.resolve() + except Exception: + resolved = item + if resolved in retained_paths: + continue + try: + item.unlink() + dirty = True + except OSError: + continue + + if dirty: + self._records = retained_records + return dirty + def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: if not isinstance(raw, dict): return {} @@ -428,6 +518,8 @@ def _load_from_disk(self) -> None: def _load_from_disk_sync(self) -> None: if not self._registry_path.exists(): + self._records = {} + self._prune_records() self._loaded = True return try: @@ -437,6 +529,7 @@ def _load_from_disk_sync(self) -> None: self._loaded = True return self._records = self._load_records_from_payload(raw) + self._prune_records() self._loaded = True async def _load_from_disk_async(self) -> None: @@ 
-447,6 +540,9 @@ async def _load_from_disk_async(self) -> None: self._loaded = True return self._records = self._load_records_from_payload(raw) + dirty = self._prune_records() + if dirty: + await self._persist() self._loaded = True async def load(self) -> None: @@ -461,6 +557,12 @@ async def _persist(self) -> None: payload = {uid: asdict(record) for uid, record in self._records.items()} await io.write_json(self._registry_path, payload, use_lock=True) + async def flush(self) -> None: + """将当前注册表状态强制落盘。""" + await self.load() + async with self._lock: + await self._persist() + def get(self, uid: str) -> AttachmentRecord | None: return self._records.get(str(uid).strip()) @@ -531,6 +633,7 @@ def _write() -> str: created_at=_now_iso(), ) self._records[uid] = record + self._prune_records() await self._persist() return record diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 7d0a73a8..38cbdf20 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -118,6 +118,7 @@ def __init__( ) self._background_tasks: set[asyncio.Task[None]] = set() + self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} # 启动队列 self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) @@ -187,6 +188,62 @@ def _can_refresh_profile_display_names(self) -> bool: cognitive_service = getattr(ai_client, "_cognitive_service", None) return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) + def _schedule_profile_display_name_refresh( + self, + *, + task_name: str, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + if not self._can_refresh_profile_display_names(): + return + + cache = getattr(self, "_profile_name_refresh_cache", None) + if cache is None: + cache = {} + self._profile_name_refresh_cache = cache + + updates: dict[str, Any] = {} + rollback: list[tuple[tuple[str, int], str | None]] = [] + + normalized_sender_name = sender_name.strip() 
+ if sender_id and normalized_sender_name: + sender_key = ("user", int(sender_id)) + previous = cache.get(sender_key) + if previous != normalized_sender_name: + cache[sender_key] = normalized_sender_name + rollback.append((sender_key, previous)) + updates["sender_id"] = sender_id + updates["sender_name"] = normalized_sender_name + + normalized_group_name = group_name.strip() + if group_id and normalized_group_name: + group_key = ("group", int(group_id)) + previous = cache.get(group_key) + if previous != normalized_group_name: + cache[group_key] = normalized_group_name + rollback.append((group_key, previous)) + updates["group_id"] = group_id + updates["group_name"] = normalized_group_name + + if not updates: + return + + async def _run_refresh() -> None: + try: + await self._refresh_profile_display_names(**updates) + except Exception: + for key, previous in rollback: + if previous is None: + cache.pop(key, None) + else: + cache[key] = previous + raise + + self._spawn_background_task(task_name, _run_refresh()) + async def handle_message(self, event: dict[str, Any]) -> None: """处理收到的消息事件""" if logger.isEnabledFor(logging.DEBUG): @@ -330,14 +387,11 @@ async def handle_message(self, event: dict[str, Any]) -> None: safe_text[:100], ) resolved_private_name = (user_name or private_sender_nickname or "").strip() - if resolved_private_name and self._can_refresh_profile_display_names(): - self._spawn_background_task( - f"profile_name_refresh_private:{private_sender_id}", - self._refresh_profile_display_names( - sender_id=private_sender_id, - sender_name=resolved_private_name, - ), - ) + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_private:{private_sender_id}", + sender_id=private_sender_id, + sender_name=resolved_private_name, + ) # 保存私聊消息到历史记录(保存处理后的内容) # 使用新的工具函数解析内容 @@ -480,18 +534,13 @@ async def handle_message(self, event: dict[str, Any]) -> None: except Exception as e: logger.warning(f"获取群聊名失败: {e}") resolved_group_sender_name = 
(sender_card or sender_nickname or "").strip() - if (resolved_group_sender_name or str(group_name or "").strip()) and ( - self._can_refresh_profile_display_names() - ): - self._spawn_background_task( - f"profile_name_refresh_group:{group_id}:{sender_id}", - self._refresh_profile_display_names( - sender_id=sender_id, - sender_name=resolved_group_sender_name, - group_id=group_id, - group_name=str(group_name or "").strip(), - ), - ) + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", + sender_id=sender_id, + sender_name=resolved_group_sender_name, + group_id=group_id, + group_name=str(group_name or "").strip(), + ) # 使用新的 utils parsed_content = await parse_message_content_for_history( @@ -652,14 +701,11 @@ async def _record_private_poke_history( display_name = resolved_sender_name or f"QQ{user_id}" normalized_user_name = user_name or display_name poke_text = _format_poke_history_text(display_name, user_id) - if resolved_sender_name and self._can_refresh_profile_display_names(): - self._spawn_background_task( - f"profile_name_refresh_private_poke:{user_id}", - self._refresh_profile_display_names( - sender_id=user_id, - sender_name=resolved_sender_name, - ), - ) + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_private_poke:{user_id}", + sender_id=user_id, + sender_name=resolved_sender_name, + ) try: await self.history_manager.add_private_message( @@ -731,18 +777,13 @@ async def _record_group_poke_history( display_name = resolved_sender_name or f"QQ{sender_id}" poke_text = _format_poke_history_text(display_name, sender_id) normalized_group_name = resolved_group_name or f"群{group_id}" - if (resolved_sender_name or resolved_group_name) and ( - self._can_refresh_profile_display_names() - ): - self._spawn_background_task( - f"profile_name_refresh_group_poke:{group_id}:{sender_id}", - self._refresh_profile_display_names( - sender_id=sender_id, - sender_name=resolved_sender_name, - 
group_id=group_id, - group_name=resolved_group_name, - ), - ) + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_group_poke:{group_id}:{sender_id}", + sender_id=sender_id, + sender_name=resolved_sender_name, + group_id=group_id, + group_name=resolved_group_name, + ) try: await self.history_manager.add_group_message( diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 098d575f..42467ec0 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -271,7 +271,10 @@ async def _execute_auto_reply(self, request: dict[str, Any]) -> None: async def send_msg_cb(message: str, reply_to: int | None = None) -> None: await self.sender.send_group_message( - group_id, message, reply_to=reply_to + group_id, + message, + reply_to=reply_to, + history_message=message, ) async def get_recent_cb( diff --git a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py index d5e07b1b..0aa440c7 100644 --- a/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py +++ b/src/Undefined/skills/agents/entertainment_agent/tools/ai_draw_one/handler.py @@ -446,19 +446,11 @@ async def _call_openai_models( response_format: str, n: int | None, timeout_val: float, + extra_params: dict[str, Any], context: dict[str, Any], ) -> _GeneratedImagePayload | str: """调用 OpenAI 兼容的图片生成接口""" - # 追加 request_params - extra_params: dict[str, Any] = {} - try: - from Undefined.config import get_config - - extra_params = get_config(strict=False).models_image_gen.request_params - except Exception: - extra_params = {} - body = _build_openai_models_request_body( prompt=prompt, model_name=model_name, @@ -878,6 +870,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: response_format=response_format, n=n_value, timeout_val=timeout_val, + 
extra_params=gen_cfg.request_params, context=context, ) else: diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index fa740566..c6de3ebb 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -36,50 +36,47 @@ def _chat_proxy_timeout_seconds() -> float: return compute_queued_llm_timeout_seconds(cfg, cfg.chat_model) -def _load_local_function_names(*roots: Path) -> set[str]: +def _load_function_name(config_path: Path) -> str | None: + try: + raw = json.loads(config_path.read_text(encoding="utf-8")) + except Exception: + return None + function = raw.get("function", {}) + if not isinstance(function, dict): + return None + name = str(function.get("name", "") or "").strip() + return name or None + + +def _load_top_level_agent_names(root: Path) -> set[str]: names: set[str] = set() - for root in roots: - if not root.exists(): + if not root.exists(): + return names + for item_dir in root.iterdir(): + if not item_dir.is_dir() or item_dir.name.startswith("_"): continue - for config_path in root.rglob("config.json"): - try: - relative_parts = config_path.relative_to(root).parts - except ValueError: - relative_parts = config_path.parts - if any(part.startswith("_") for part in relative_parts): - continue - try: - raw = json.loads(config_path.read_text(encoding="utf-8")) - except Exception: - continue - function = raw.get("function", {}) - if not isinstance(function, dict): - continue - name = str(function.get("name", "") or "").strip() - if name: - names.add(name) + config_path = item_dir / "config.json" + if not config_path.exists(): + continue + name = _load_function_name(config_path) + if name: + names.add(name) return names def _get_local_agent_tool_names() -> set[str]: skills_root = Path(__file__).resolve().parents[2] / "skills" - return _load_local_function_names(skills_root / "agents") - - -def _get_local_tool_names() -> set[str]: - skills_root = Path(__file__).resolve().parents[2] / 
"skills" - return _load_local_function_names(skills_root / "tools", skills_root / "toolsets") + return _load_top_level_agent_names(skills_root / "agents") def _tool_invoke_proxy_timeout_seconds(tool_name: str) -> float | None: normalized_name = str(tool_name or "").strip() if normalized_name in _get_local_agent_tool_names(): return None - if normalized_name not in _get_local_tool_names(): - return None cfg = get_config(strict=False) - # 普通 tool 保持 Runtime API 超时 + 60s 网络缓冲。 + # 非 agent 一律保留 Runtime API 超时 + 60s 网络缓冲, + # 包括 toolsets、MCP/external tools 以及本地未知名称。 return float(cfg.api.tool_invoke_timeout) + 60.0 diff --git a/tests/test_ai_coordinator_queue_routing.py b/tests/test_ai_coordinator_queue_routing.py index 16266360..c8d68233 100644 --- a/tests/test_ai_coordinator_queue_routing.py +++ b/tests/test_ai_coordinator_queue_routing.py @@ -6,6 +6,7 @@ import pytest +from Undefined.services import ai_coordinator as ai_coordinator_module from Undefined.services.ai_coordinator import AICoordinator @@ -190,3 +191,51 @@ async def test_handle_private_reply_avoids_extra_blank_line_without_attachments( assert await_args is not None request_data = await_args.args[0] assert "\n\n " not in request_data["full_question"] + + +@pytest.mark.asyncio +async def test_execute_auto_reply_send_msg_cb_passes_history_message( + monkeypatch: pytest.MonkeyPatch, +) -> None: + coordinator: Any = object.__new__(AICoordinator) + sender = SimpleNamespace(send_group_message=AsyncMock()) + + async def _fake_ask(*_args: Any, **kwargs: Any) -> str: + await kwargs["send_message_callback"]("hello group") + return "" + + coordinator.config = SimpleNamespace(bot_qq=10000) + coordinator.sender = sender + coordinator.ai = SimpleNamespace( + ask=_fake_ask, + memory_storage=SimpleNamespace(), + runtime_config=SimpleNamespace(), + ) + coordinator.history_manager = SimpleNamespace() + coordinator.onebot = SimpleNamespace( + get_image=AsyncMock(), + get_forward_msg=AsyncMock(), + send_like=AsyncMock(), + ) + 
coordinator.scheduler = SimpleNamespace() + + monkeypatch.setattr( + ai_coordinator_module, "collect_context_resources", lambda _vars: {} + ) + + await coordinator._execute_auto_reply( + { + "group_id": 12345, + "sender_id": 20001, + "sender_name": "member", + "group_name": "测试群", + "full_question": "prompt", + } + ) + + sender.send_group_message.assert_awaited_once_with( + 12345, + "hello group", + reply_to=None, + history_message="hello group", + ) diff --git a/tests/test_attachments.py b/tests/test_attachments.py index d2377d15..edf5a9a2 100644 --- a/tests/test_attachments.py +++ b/tests/test_attachments.py @@ -54,6 +54,9 @@ async def test_attachment_registry_load_uses_async_read_json( ) -> None: registry_path = tmp_path / "attachment_registry.json" cache_dir = tmp_path / "attachments" + cache_dir.mkdir(parents=True, exist_ok=True) + cached_file = cache_dir / "pic_async123.png" + cached_file.write_bytes(_PNG_BYTES) seen_calls: list[tuple[Path, bool]] = [] payload = { "pic_async123": { @@ -64,7 +67,7 @@ async def test_attachment_registry_load_uses_async_read_json( "display_name": "cat.png", "source_kind": "test", "source_ref": "test", - "local_path": str(cache_dir / "pic_async123.png"), + "local_path": str(cached_file), "mime_type": "image/png", "sha256": "digest", "created_at": "2026-04-02T00:00:00", @@ -175,3 +178,55 @@ async def test_render_message_with_pic_placeholders_uses_file_uri_and_shadow_tex assert "[CQ:image,file=file://" in rendered.delivery_text assert f"[图片 uid={record.uid} name=cat.png]" in rendered.history_text + + +@pytest.mark.asyncio +async def test_attachment_registry_prunes_old_records_and_files(tmp_path: Path) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + max_records=1, + ) + + first = await registry.register_bytes( + "group:10001", + _PNG_BYTES, + kind="image", + display_name="first.png", + source_kind="test", + ) + first_path = 
Path(str(first.local_path)) + second = await registry.register_bytes( + "group:10001", + _PNG_BYTES + b"2", + kind="image", + display_name="second.png", + source_kind="test", + ) + + assert registry.resolve(first.uid, "group:10001") is None + assert registry.resolve(second.uid, "group:10001") is not None + assert first_path.exists() is False + cache_files = [ + item for item in (tmp_path / "attachments").iterdir() if item.is_file() + ] + assert len(cache_files) == 1 + assert cache_files[0].name.startswith(second.uid) + + +@pytest.mark.asyncio +async def test_attachment_registry_load_prunes_orphan_cache_files( + tmp_path: Path, +) -> None: + cache_dir = tmp_path / "attachments" + cache_dir.mkdir(parents=True, exist_ok=True) + orphan = cache_dir / "orphan.png" + orphan.write_bytes(_PNG_BYTES) + + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=cache_dir, + ) + await registry.load() + + assert orphan.exists() is False diff --git a/tests/test_handlers_poke_history.py b/tests/test_handlers_poke_history.py index ede69a23..1ec55e51 100644 --- a/tests/test_handlers_poke_history.py +++ b/tests/test_handlers_poke_history.py @@ -189,3 +189,40 @@ def _fake_spawn(name: str, coroutine: Any) -> None: group_history_call.kwargs["text_content"] == "QQ20001(暱称)[20001(QQ号)] 拍了拍你。" ) + + +@pytest.mark.asyncio +async def test_schedule_profile_display_name_refresh_deduplicates_same_name() -> None: + handler = _build_handler() + handler.ai = SimpleNamespace(_cognitive_service=SimpleNamespace(enabled=True)) + handler._refresh_profile_display_names = AsyncMock() + scheduled: list[tuple[str, Any]] = [] + + def _fake_spawn(name: str, coroutine: Any) -> None: + scheduled.append((name, coroutine)) + + handler._spawn_background_task = _fake_spawn + + handler._schedule_profile_display_name_refresh( + task_name="profile_name_refresh_group:30001:20001", + sender_id=20001, + sender_name="群名片", + group_id=30001, + group_name="测试群", + ) + 
handler._schedule_profile_display_name_refresh( + task_name="profile_name_refresh_group:30001:20001", + sender_id=20001, + sender_name="群名片", + group_id=30001, + group_name="测试群", + ) + + assert len(scheduled) == 1 + await scheduled[0][1] + handler._refresh_profile_display_names.assert_awaited_once_with( + sender_id=20001, + sender_name="群名片", + group_id=30001, + group_name="测试群", + ) diff --git a/tests/test_queue_timeout_budgets.py b/tests/test_queue_timeout_budgets.py index 44a94807..31eda12c 100644 --- a/tests/test_queue_timeout_budgets.py +++ b/tests/test_queue_timeout_budgets.py @@ -146,11 +146,6 @@ def test_tool_invoke_proxy_timeout_uses_local_schema_sets( "_get_local_agent_tool_names", lambda: {"custom_agent_runner"}, ) - monkeypatch.setattr( - runtime_routes, - "_get_local_tool_names", - lambda: {"messages.send_message"}, - ) assert ( runtime_routes._tool_invoke_proxy_timeout_seconds("custom_agent_runner") is None @@ -159,4 +154,23 @@ def test_tool_invoke_proxy_timeout_uses_local_schema_sets( runtime_routes._tool_invoke_proxy_timeout_seconds("messages.send_message") == 180.0 ) - assert runtime_routes._tool_invoke_proxy_timeout_seconds("unknown_tool") is None + assert runtime_routes._tool_invoke_proxy_timeout_seconds("unknown_tool") == 180.0 + + +def test_load_top_level_agent_names_ignores_nested_agent_tools(tmp_path: Path) -> None: + agent_dir = tmp_path / "demo_agent" + agent_dir.mkdir() + (agent_dir / "config.json").write_text( + '{"function":{"name":"demo_agent"}}', + encoding="utf-8", + ) + nested_tool_dir = agent_dir / "tools" / "helper" + nested_tool_dir.mkdir(parents=True) + (nested_tool_dir / "config.json").write_text( + '{"function":{"name":"helper_tool"}}', + encoding="utf-8", + ) + + names = runtime_routes._load_top_level_agent_names(tmp_path) + + assert names == {"demo_agent"} From fa77d8d28770e7707ad36a4f7b2f72aa3bac0d19 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Fri, 3 Apr 2026 10:44:32 +0800 Subject: [PATCH 21/21] 
fix(attachments): remove constructor disk load Stop AttachmentRegistry from reading disk during construction and keep initial registry loading explicit and async. Co-authored-by: GPT-5.4 xhigh --- src/Undefined/attachments.py | 26 -------------------------- tests/test_attachments.py | 1 + 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index 2d8e2603..588247bc 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -8,7 +8,6 @@ from dataclasses import asdict, dataclass from datetime import datetime import hashlib -import json import logging import mimetypes from pathlib import Path @@ -392,7 +391,6 @@ def __init__( self._records: dict[str, AttachmentRecord] = {} self._loaded = False self._load_task: asyncio.Task[None] | None = None - self._load_from_disk() def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: text = str(raw_path or "").strip() @@ -508,30 +506,6 @@ def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: continue return loaded - def _load_from_disk(self) -> None: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - self._load_from_disk_sync() - return - self._load_task = loop.create_task(self._load_from_disk_async()) - - def _load_from_disk_sync(self) -> None: - if not self._registry_path.exists(): - self._records = {} - self._prune_records() - self._loaded = True - return - try: - raw = json.loads(self._registry_path.read_text(encoding="utf-8")) - except Exception as exc: - logger.warning("[AttachmentRegistry] 读取失败: %s", exc) - self._loaded = True - return - self._records = self._load_records_from_payload(raw) - self._prune_records() - self._loaded = True - async def _load_from_disk_async(self) -> None: try: raw = await io.read_json(self._registry_path, use_lock=False) diff --git a/tests/test_attachments.py b/tests/test_attachments.py index edf5a9a2..1edf17a1 100644 --- 
a/tests/test_attachments.py +++ b/tests/test_attachments.py @@ -87,6 +87,7 @@ def _unexpected_sync_read_text(_self: Path, *_args: Any, **_kwargs: Any) -> str: monkeypatch.setattr(Path, "read_text", _unexpected_sync_read_text) registry = AttachmentRegistry(registry_path=registry_path, cache_dir=cache_dir) + assert seen_calls == [] await registry.load() assert seen_calls == [(registry_path, False)]