Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cecli/commands/load_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def execute(cls, io, coder, args, **kwargs):
from cecli import sessions

session_manager = sessions.SessionManager(coder, io)
session_manager.load_session(args.strip())
await session_manager.load_session(args.strip())

return format_command_result(io, "load-session", f"Loaded session: {args.strip()}")

Expand Down
2 changes: 1 addition & 1 deletion cecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ def get_io(pretty):
from cecli.sessions import SessionManager

session_manager = SessionManager(coder, io)
session_manager.load_session(
await session_manager.load_session(
args.auto_save_session_name if args.auto_save_session_name else "auto-save"
)
except Exception:
Expand Down
86 changes: 83 additions & 3 deletions cecli/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def list_sessions(self) -> List[Dict]:

return sessions

def load_session(self, session_identifier: str) -> bool:
async def load_session(self, session_identifier: str) -> bool:
"""Load a saved session by name or file path."""
if not session_identifier:
self.io.tool_error("Please provide a session name or file path.")
Expand All @@ -112,7 +112,17 @@ def load_session(self, session_identifier: str) -> bool:
return False

# Apply session data
return self._apply_session_data(session_data, session_file)
applied = await self._apply_session_data(session_data, session_file)
if applied:
from cecli.commands import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=self.coder.edit_format,
from_coder=self.coder,
summarize_from_coder=False,
show_announcements=True,
)
return applied

def _build_session_data(self, session_name) -> Dict:
"""Build session data dictionary from current coder state."""
Expand Down Expand Up @@ -140,6 +150,39 @@ def _build_session_data(self, session_name) -> Dict:
self.io.tool_warning(f"Could not read todo list file: {e}")

# Get CUR and DONE messages from ConversationManager
connected_mcps = []
if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager:
connected_mcps = [server.name for server in self.coder.mcp_manager.connected_servers]

# Get CUR and DONE messages from ConversationManager
connected_mcps = []
if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager:
connected_mcps = [server.name for server in self.coder.mcp_manager.connected_servers]

skills_data = None
if hasattr(self.coder, "skills_manager") and self.coder.skills_manager:
skills_data = {
"skills_paths": [str(p) for p in self.coder.skills_manager.directory_paths],
"skills_includelist": (
list(self.coder.skills_manager.include_list)
if self.coder.skills_manager.include_list is not None
else []
),
"skills_excludelist": (
list(self.coder.skills_manager.exclude_list)
if self.coder.skills_manager.exclude_list is not None
else []
),
}

agent_config_data = None
if hasattr(self.coder, "agent_config"):
agent_config_data = {
"tools_paths": self.coder.agent_config.get("tools_paths", []),
"tools_includelist": self.coder.agent_config.get("tools_includelist", []),
"tools_excludelist": self.coder.agent_config.get("tools_excludelist", []),
}

return {
"version": 1,
"session_name": session_name,
Expand Down Expand Up @@ -168,6 +211,9 @@ def _build_session_data(self, session_name) -> Dict:
"auto_test": self.coder.auto_test,
},
"todo_list": todo_content,
"mcps": connected_mcps,
"skills": skills_data,
"tools": agent_config_data,
}

def _find_session_file(self, session_identifier: str) -> Optional[Path]:
Expand All @@ -194,7 +240,7 @@ def _find_session_file(self, session_identifier: str) -> Optional[Path]:
self.io.tool_output("Use /list-sessions to see available sessions.")
return None

def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool:
async def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool:
"""Apply session data to current coder state."""
try:
# Clear current state
Expand Down Expand Up @@ -303,6 +349,40 @@ def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool:
)
self.io.tool_output(f"Loaded {num_messages} messages and {num_files} files")

# Load MCPs
saved_mcps = session_data.get("mcps", [])
if hasattr(self.coder, "mcp_manager") and self.coder.mcp_manager:
current_mcps = {server.name for server in self.coder.mcp_manager.connected_servers}
saved_mcps_set = set(saved_mcps)

to_disconnect = current_mcps - saved_mcps_set
for mcp_name in to_disconnect:
await self.coder.mcp_manager.disconnect_server(mcp_name)

to_connect = saved_mcps_set - current_mcps
for mcp_name in to_connect:
await self.coder.mcp_manager.connect_server(mcp_name)

# Load skills
skills_data = session_data.get("skills")
if skills_data and hasattr(self.coder, "skills_manager") and self.coder.skills_manager:
self.coder.skills_manager.directory_paths = skills_data.get("skills_paths", [])
self.coder.skills_manager.include_list = set(
skills_data.get("skills_includelist", [])
)
self.coder.skills_manager.exclude_list = set(
skills_data.get("skills_excludelist", [])
)

# Load tools config
agent_config_data = session_data.get("tools")
if agent_config_data and hasattr(self.coder, "agent_config"):
self.coder.agent_config.update(agent_config_data)
from cecli.tools.utils.registry import ToolRegistry

ToolRegistry.build_registry(agent_config=self.coder.agent_config)
self.coder.loaded_custom_tools = ToolRegistry.loaded_custom_tools

return True

except Exception as e:
Expand Down
18 changes: 9 additions & 9 deletions cecli/tools/edit_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class Tool(BaseTool):
"type": "string",
"enum": ["replace", "delete", "insert"],
"description": (
"The type of operation: 'replace' (replace range with text), "
"'delete' (remove range), or 'insert' (insert text after start_line). "
"Defaults to 'replace'."
"The type of operation: 'replace' (replace range with"
" text), 'delete' (remove range), or 'insert' (insert text"
" after start_line). Defaults to 'replace'."
),
},
"text": {
Expand All @@ -78,8 +78,8 @@ class Tool(BaseTool):
"end_line": {
"type": "string",
"description": (
'Hashline format for end line: "{4 char hash}" (without the '
"braces)"
'Hashline format for end line: "{4 char hash}" (without the'
" braces)"
),
},
},
Expand Down Expand Up @@ -179,8 +179,8 @@ def execute(
if operation in ("replace", "delete"):
if edit_start_line is None:
raise ToolError(
f"Edit {edit_index + 1}: 'start_line' parameter is required "
f"for '{operation}' operation"
f"Edit {edit_index + 1}: 'start_line' parameter is required"
f" for '{operation}' operation"
)
if edit_end_line is None:
raise ToolError(
Expand All @@ -190,8 +190,8 @@ def execute(
if operation == "insert":
if edit_start_line is None:
raise ToolError(
f"Edit {edit_index + 1}: 'start_line' parameter is required "
"for 'insert' operation"
f"Edit {edit_index + 1}: 'start_line' parameter is required"
" for 'insert' operation"
)
# For insert, end_line defaults to start_line
edit_end_line = edit_end_line or edit_start_line
Expand Down
113 changes: 112 additions & 1 deletion tests/basic/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from unittest import TestCase, mock

from cecli.coders import Coder
from cecli.commands import Commands
from cecli.commands import Commands, SwitchCoderSignal
from cecli.helpers.file_searcher import handle_core_files
from cecli.io import InputOutput
from cecli.models import Model
Expand Down Expand Up @@ -196,3 +196,114 @@ async def test_preserve_todo_list_deprecated(self):
self.assertTrue(
any("deprecated" in call[0][0] for call in mock_tool_warning.call_args_list)
)

async def test_cmd_save_load_session_agent_config(self):
"""Test session save/load for agent-specific configs (mcp, skills, tools)."""
with GitTemporaryDirectory():
# Mock args for AgentCoder
mock_args = mock.MagicMock()
mock_args.agent_config = json.dumps(
{
"tools_paths": ["/test/tools/path"],
"tools_includelist": ["included_tool"],
"tools_excludelist": ["excluded_tool"],
}
)
# This is needed for the skills manager to be created
mock_args.skills_paths = ["/test/skills/path"]
mock_args.mcp_servers = json.dumps([{"name": "mock_mcp"}])
mock_args.mcp_servers_files = []
mock_args.verbose = False
mock_args.debug = False
mock_args.tui = False
mock_args.auto_save_session_name = "auto-save"
mock_args.auto_save = False
mock_args.auto_load = False
mock_args.yes_always_commands = True
mock_args.command_prefix = None
mock_args.file_diffs = True
mock_args.max_reflections = 3
mock_args.model = "gpt-3.5-turbo"
mock_args.weak_model = None
mock_args.editor_model = None
mock_args.agent_model = None
mock_args.editor_edit_format = None
mock_args.retries = None
mock_args.reasoning_effort = None
mock_args.thinking_tokens = None
mock_args.check_model_accepts_settings = True
mock_args.copy_paste = False
mock_args.hooks = None

io = InputOutput(pretty=False, fancy_input=False, yes=True)

# === SAVE SESSION ===
coder_to_save = await Coder.create(
self.GPT35, "agent", io, args=mock_args, repo=mock.MagicMock()
)
commands_to_save = Commands(io, coder_to_save, args=mock_args)

# Configure state to be saved
await coder_to_save.mcp_manager.connect_server("mock_mcp")
coder_to_save.skills_manager.include_list = {"included_skill"}
coder_to_save.skills_manager.exclude_list = {"excluded_skill"}
coder_to_save.skills_manager.directory_paths = ["/test/skills/path/saved"]

session_name = "agent_session"
await commands_to_save.execute("save-session", session_name)

session_file = Path(handle_core_files(".cecli")) / "sessions" / f"{session_name}.json"
self.assertTrue(session_file.exists())

with open(session_file, "r", encoding="utf-8") as f:
saved_data = json.load(f)

# Assert saved data is correct
self.assertEqual(saved_data["mcps"], ["mock_mcp"])
self.assertEqual(saved_data["skills"]["skills_paths"], ["/test/skills/path/saved"])
self.assertEqual(saved_data["skills"]["skills_includelist"], ["included_skill"])
self.assertEqual(saved_data["skills"]["skills_excludelist"], ["excluded_skill"])
self.assertEqual(saved_data["tools"]["tools_paths"], ["/test/tools/path"])
self.assertEqual(saved_data["tools"]["tools_includelist"], ["included_tool"])
self.assertEqual(saved_data["tools"]["tools_excludelist"], ["excluded_tool"])

# === LOAD SESSION ===
# Create a new coder to load into, ensuring it's a clean slate
coder_to_load_initial = await Coder.create(
self.GPT35, "agent", io, args=mock_args, repo=mock.MagicMock()
)
commands_to_load = Commands(io, coder_to_load_initial, args=mock_args)

# Mock ToolRegistry.build_registry to check if it's called
with mock.patch(
"cecli.tools.utils.registry.ToolRegistry.build_registry"
) as mock_build_registry:
coder_after_load = None
try:
await commands_to_load.execute("load-session", session_name)
except SwitchCoderSignal as e:
# The SwitchCoderSignal is expected, we need to get the new coder from it
coder_after_load = await Coder.create(**e.kwargs)

self.assertIsNotNone(coder_after_load)

# Assert loaded state is correct in the new coder instance
connected_mcps = {s.name for s in coder_after_load.mcp_manager.connected_servers}
self.assertIn("mock_mcp", connected_mcps)

self.assertEqual(
coder_after_load.skills_manager.directory_paths, ["/test/skills/path/saved"]
)
self.assertEqual(coder_after_load.skills_manager.include_list, {"included_skill"})
self.assertEqual(coder_after_load.skills_manager.exclude_list, {"excluded_skill"})

self.assertEqual(coder_after_load.agent_config["tools_paths"], ["/test/tools/path"])
self.assertEqual(
coder_after_load.agent_config["tools_includelist"], ["included_tool"]
)
self.assertEqual(
coder_after_load.agent_config["tools_excludelist"], ["excluded_tool"]
)

# Assert that the tool registry was rebuilt
mock_build_registry.assert_called_with(agent_config=coder_after_load.agent_config)
Loading