diff --git a/cecli/commands/load_session.py b/cecli/commands/load_session.py index 1d5676d97e9..f3c38396a8b 100644 --- a/cecli/commands/load_session.py +++ b/cecli/commands/load_session.py @@ -18,7 +18,7 @@ async def execute(cls, io, coder, args, **kwargs): from cecli import sessions session_manager = sessions.SessionManager(coder, io) - session_manager.load_session(args.strip()) + await session_manager.load_session(args.strip()) return format_command_result(io, "load-session", f"Loaded session: {args.strip()}") diff --git a/cecli/main.py b/cecli/main.py index bf8b89fa99d..bfebaffc6d1 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -1198,7 +1198,7 @@ def get_io(pretty): from cecli.sessions import SessionManager session_manager = SessionManager(coder, io) - session_manager.load_session( + await session_manager.load_session( args.auto_save_session_name if args.auto_save_session_name else "auto-save" ) except Exception: diff --git a/cecli/sessions.py b/cecli/sessions.py index 5d8447d5213..c1e9fbdc5f3 100644 --- a/cecli/sessions.py +++ b/cecli/sessions.py @@ -88,7 +88,7 @@ def list_sessions(self) -> List[Dict]: return sessions - def load_session(self, session_identifier: str) -> bool: + async def load_session(self, session_identifier: str) -> bool: """Load a saved session by name or file path.""" if not session_identifier: self.io.tool_error("Please provide a session name or file path.") @@ -112,7 +112,17 @@ def load_session(self, session_identifier: str) -> bool: return False # Apply session data - return self._apply_session_data(session_data, session_file) + applied = await self._apply_session_data(session_data, session_file) + if applied: + from cecli.commands import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=self.coder.edit_format, + from_coder=self.coder, + summarize_from_coder=False, + show_announcements=True, + ) + return applied def _build_session_data(self, session_name) -> Dict: """Build session data dictionary from current coder state.""" @@ -140,6 +150,39 @@ def _build_session_data(self, session_name) -> Dict: self.io.tool_warning(f"Could not read todo list file: {e}") # Get CUR and DONE messages from ConversationManager + connected_mcps = [] + if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager: + connected_mcps = [server.name for server in self.coder.mcp_manager.connected_servers] + + # Get CUR and DONE messages from ConversationManager + connected_mcps = [] + if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager: + connected_mcps = [server.name for server in self.coder.mcp_manager.connected_servers] + + skills_data = None + if hasattr(self.coder, "skills_manager") and self.coder.skills_manager: + skills_data = { + "skills_paths": [str(p) for p in self.coder.skills_manager.directory_paths], + "skills_includelist": ( + list(self.coder.skills_manager.include_list) + if self.coder.skills_manager.include_list is not None + else [] + ), + "skills_excludelist": ( + list(self.coder.skills_manager.exclude_list) + if self.coder.skills_manager.exclude_list is not None + else [] + ), + } + + agent_config_data = None + if hasattr(self.coder, "agent_config"): + agent_config_data = { + "tools_paths": self.coder.agent_config.get("tools_paths", []), + "tools_includelist": self.coder.agent_config.get("tools_includelist", []), + "tools_excludelist": self.coder.agent_config.get("tools_excludelist", []), + } + return { "version": 1, "session_name": session_name, @@ -168,6 +211,9 @@ def _build_session_data(self, session_name) -> Dict: "auto_test": self.coder.auto_test, }, "todo_list": todo_content, + "mcps": connected_mcps, + "skills": skills_data, + "tools": agent_config_data, } def _find_session_file(self, session_identifier: str) -> Optional[Path]: @@ -194,7 +240,7 @@ def _find_session_file(self, session_identifier: str) -> Optional[Path]: self.io.tool_output("Use /list-sessions to see available sessions.") return None - def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool: + async def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool: """Apply session data to current coder state.""" try: # Clear current state @@ -303,6 +349,40 @@ def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool: ) self.io.tool_output(f"Loaded {num_messages} messages and {num_files} files") + # Load MCPs + saved_mcps = session_data.get("mcps", []) + if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager: + current_mcps = {server.name for server in self.coder.mcp_manager.connected_servers} + saved_mcps_set = set(saved_mcps) + + to_disconnect = current_mcps - saved_mcps_set + for mcp_name in to_disconnect: + await self.coder.mcp_manager.disconnect_server(mcp_name) + + to_connect = saved_mcps_set - current_mcps + for mcp_name in to_connect: + await self.coder.mcp_manager.connect_server(mcp_name) + + # Load skills + skills_data = session_data.get("skills") + if skills_data and hasattr(self.coder, "skills_manager") and self.coder.skills_manager: + self.coder.skills_manager.directory_paths = skills_data.get("skills_paths", []) + self.coder.skills_manager.include_list = set( + skills_data.get("skills_includelist", []) + ) + self.coder.skills_manager.exclude_list = set( + skills_data.get("skills_excludelist", []) + ) + + # Load tools config + agent_config_data = session_data.get("tools") + if agent_config_data and hasattr(self.coder, "agent_config"): + self.coder.agent_config.update(agent_config_data) + from cecli.tools.utils.registry import ToolRegistry + + ToolRegistry.build_registry(agent_config=self.coder.agent_config) + self.coder.loaded_custom_tools = ToolRegistry.loaded_custom_tools + return True except Exception as e: diff --git a/cecli/tools/edit_text.py b/cecli/tools/edit_text.py index f03fc6d96df..5b4d64f7c3c 100644 --- a/cecli/tools/edit_text.py +++ b/cecli/tools/edit_text.py @@ -56,9 +56,9 @@ class Tool(BaseTool): "type": "string", "enum": ["replace", "delete", "insert"], "description": ( - "The type of operation: 'replace' (replace range with text), " - "'delete' (remove range), or 'insert' (insert text after start_line). " - "Defaults to 'replace'." + "The type of operation: 'replace' (replace range with" + " text), 'delete' (remove range), or 'insert' (insert text" + " after start_line). Defaults to 'replace'." ), }, "text": { @@ -78,8 +78,8 @@ class Tool(BaseTool): "end_line": { "type": "string", "description": ( - 'Hashline format for end line: "{4 char hash}" (without the ' - "braces)" + 'Hashline format for end line: "{4 char hash}" (without the' + " braces)" ), }, }, @@ -179,8 +179,8 @@ def execute( if operation in ("replace", "delete"): if edit_start_line is None: raise ToolError( - f"Edit {edit_index + 1}: 'start_line' parameter is required " - f"for '{operation}' operation" + f"Edit {edit_index + 1}: 'start_line' parameter is required" + f" for '{operation}' operation" ) if edit_end_line is None: raise ToolError( @@ -190,8 +190,8 @@ def execute( if operation == "insert": if edit_start_line is None: raise ToolError( - f"Edit {edit_index + 1}: 'start_line' parameter is required " - "for 'insert' operation" + f"Edit {edit_index + 1}: 'start_line' parameter is required" + " for 'insert' operation" ) # For insert, end_line defaults to start_line edit_end_line = edit_end_line or edit_start_line diff --git a/tests/basic/test_sessions.py b/tests/basic/test_sessions.py index adb5a01a907..aa26f4f5a26 100644 --- a/tests/basic/test_sessions.py +++ b/tests/basic/test_sessions.py @@ -7,7 +7,7 @@ from unittest import TestCase, mock from cecli.coders import Coder -from cecli.commands import Commands +from cecli.commands import Commands, SwitchCoderSignal from cecli.helpers.file_searcher import handle_core_files from cecli.io import InputOutput from cecli.models import Model @@ -196,3 +196,114 @@ async def test_preserve_todo_list_deprecated(self): self.assertTrue( any("deprecated" in call[0][0] for call in mock_tool_warning.call_args_list) ) + + async def test_cmd_save_load_session_agent_config(self): + """Test session save/load for agent-specific configs (mcp, skills, tools).""" + with GitTemporaryDirectory(): + # Mock args for AgentCoder + mock_args = mock.MagicMock() + mock_args.agent_config = json.dumps( + { + "tools_paths": ["/test/tools/path"], + "tools_includelist": ["included_tool"], + "tools_excludelist": ["excluded_tool"], + } + ) + # This is needed for the skills manager to be created + mock_args.skills_paths = ["/test/skills/path"] + mock_args.mcp_servers = json.dumps([{"name": "mock_mcp"}]) + mock_args.mcp_servers_files = [] + mock_args.verbose = False + mock_args.debug = False + mock_args.tui = False + mock_args.auto_save_session_name = "auto-save" + mock_args.auto_save = False + mock_args.auto_load = False + mock_args.yes_always_commands = True + mock_args.command_prefix = None + mock_args.file_diffs = True + mock_args.max_reflections = 3 + mock_args.model = "gpt-3.5-turbo" + mock_args.weak_model = None + mock_args.editor_model = None + mock_args.agent_model = None + mock_args.editor_edit_format = None + mock_args.retries = None + mock_args.reasoning_effort = None + mock_args.thinking_tokens = None + mock_args.check_model_accepts_settings = True + mock_args.copy_paste = False + mock_args.hooks = None + + io = InputOutput(pretty=False, fancy_input=False, yes=True) + + # === SAVE SESSION === + coder_to_save = await Coder.create( + self.GPT35, "agent", io, args=mock_args, repo=mock.MagicMock() + ) + commands_to_save = Commands(io, coder_to_save, args=mock_args) + + # Configure state to be saved + await coder_to_save.mcp_manager.connect_server("mock_mcp") + coder_to_save.skills_manager.include_list = {"included_skill"} + coder_to_save.skills_manager.exclude_list = {"excluded_skill"} + coder_to_save.skills_manager.directory_paths = ["/test/skills/path/saved"] + + session_name = "agent_session" + await commands_to_save.execute("save-session", session_name) + + session_file = Path(handle_core_files(".cecli")) / "sessions" / f"{session_name}.json" + self.assertTrue(session_file.exists()) + + with open(session_file, "r", encoding="utf-8") as f: + saved_data = json.load(f) + + # Assert saved data is correct + self.assertEqual(saved_data["mcps"], ["mock_mcp"]) + self.assertEqual(saved_data["skills"]["skills_paths"], ["/test/skills/path/saved"]) + self.assertEqual(saved_data["skills"]["skills_includelist"], ["included_skill"]) + self.assertEqual(saved_data["skills"]["skills_excludelist"], ["excluded_skill"]) + self.assertEqual(saved_data["tools"]["tools_paths"], ["/test/tools/path"]) + self.assertEqual(saved_data["tools"]["tools_includelist"], ["included_tool"]) + self.assertEqual(saved_data["tools"]["tools_excludelist"], ["excluded_tool"]) + + # === LOAD SESSION === + # Create a new coder to load into, ensuring it's a clean slate + coder_to_load_initial = await Coder.create( + self.GPT35, "agent", io, args=mock_args, repo=mock.MagicMock() + ) + commands_to_load = Commands(io, coder_to_load_initial, args=mock_args) + + # Mock ToolRegistry.build_registry to check if it's called + with mock.patch( + "cecli.tools.utils.registry.ToolRegistry.build_registry" + ) as mock_build_registry: + coder_after_load = None + try: + await commands_to_load.execute("load-session", session_name) + except SwitchCoderSignal as e: + # The SwitchCoderSignal is expected, we need to get the new coder from it + coder_after_load = await Coder.create(**e.kwargs) + + self.assertIsNotNone(coder_after_load) + + # Assert loaded state is correct in the new coder instance + connected_mcps = {s.name for s in coder_after_load.mcp_manager.connected_servers} + self.assertIn("mock_mcp", connected_mcps) + + self.assertEqual( + coder_after_load.skills_manager.directory_paths, ["/test/skills/path/saved"] + ) + self.assertEqual(coder_after_load.skills_manager.include_list, {"included_skill"}) + self.assertEqual(coder_after_load.skills_manager.exclude_list, {"excluded_skill"}) + + self.assertEqual(coder_after_load.agent_config["tools_paths"], ["/test/tools/path"]) + self.assertEqual( + coder_after_load.agent_config["tools_includelist"], ["included_tool"] + ) + self.assertEqual( + coder_after_load.agent_config["tools_excludelist"], ["excluded_tool"] + ) + + # Assert that the tool registry was rebuilt + mock_build_registry.assert_called_with(agent_config=coder_after_load.agent_config)