Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def _cancel_event():
shimmer.stop()
stream_buf.discard()
print_turn_complete()
print_plan()
session = session_holder[0] if session_holder else None
print_plan(session=session)
if session is not None:
await session.send_deferred_turn_complete_notification(event)
turn_complete_event.set()
Expand Down
13 changes: 8 additions & 5 deletions agent/tools/hf_repo_git_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,14 @@ async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
repo_type = args.get("repo_type", "model")
status = args.get("status", "all") # open, closed, all

discussions = list(self.api.get_repo_discussions(
repo_id=repo_id,
repo_type=repo_type,
discussion_status=status if status != "all" else None,
))
def _fetch():
return list(self.api.get_repo_discussions(
repo_id=repo_id,
repo_type=repo_type,
discussion_status=status if status != "all" else None,
))

discussions = await _async_call(_fetch)

if not discussions:
return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
Expand Down
23 changes: 15 additions & 8 deletions agent/tools/plan_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from .types import ToolResult

# In-memory storage for the current plan (raw structure from agent)
_current_plan: List[Dict[str, str]] = []
# Per-session plan storage to avoid cross-session corruption
_session_plans: Dict[str, List[Dict[str, str]]] = {}
_last_plan_session_id: str | None = None


class PlanTool:
Expand All @@ -26,7 +27,7 @@ async def execute(self, params: Dict[str, Any]) -> ToolResult:
Returns:
ToolResult with formatted output
"""
global _current_plan
global _session_plans, _last_plan_session_id

todos = params.get("todos", [])

Expand Down Expand Up @@ -54,8 +55,10 @@ async def execute(self, params: Dict[str, Any]) -> ToolResult:
"isError": True,
}

# Store the raw todos structure in memory
_current_plan = todos
# Store per-session to prevent cross-session plan corruption
session_id = self.session.session_id if self.session else "__no_session__"
_session_plans[session_id] = todos
_last_plan_session_id = session_id

# Emit plan update event if session is available
if self.session:
Expand All @@ -76,9 +79,13 @@ async def execute(self, params: Dict[str, Any]) -> ToolResult:
}


def get_current_plan() -> List[Dict[str, str]]:
"""Get the current plan (raw structure)."""
return _current_plan
def get_current_plan(session_id: str | None = None) -> List[Dict[str, str]]:
"""Get the current plan for a session (raw structure)."""
if session_id:
return _session_plans.get(session_id, [])
if _last_plan_session_id:
return _session_plans.get(_last_plan_session_id, [])
return []


# Tool specification
Expand Down
9 changes: 5 additions & 4 deletions agent/utils/terminal_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,12 @@ def print_help() -> None:

# ── Plan display ───────────────────────────────────────────────────────

def format_plan_display() -> str:
def format_plan_display(session=None) -> str:
"""Format the current plan for display."""
from agent.tools.plan_tool import get_current_plan

plan = get_current_plan()
session_id = session.session_id if session else None
plan = get_current_plan(session_id=session_id)
if not plan:
return ""

Expand All @@ -462,8 +463,8 @@ def format_plan_display() -> str:
return "\n".join(lines)


def print_plan() -> None:
plan_str = format_plan_display()
def print_plan(session=None) -> None:
plan_str = format_plan_display(session=session)
if plan_str:
_console.print(plan_str)

Expand Down
21 changes: 14 additions & 7 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,22 @@ async def lifespan(app: FastAPI):
lifespan=lifespan,
)

# CORS middleware for development
# CORS middleware
allow_origins = [
"http://localhost:5173", # Vite dev server
"http://localhost:3000",
"http://127.0.0.1:5173",
"http://127.0.0.1:3000",
]

# Add production HF Spaces URL when deployed
space_host = os.environ.get("SPACE_HOST")
if space_host:
allow_origins.append(f"https://{space_host}")

app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:5173", # Vite dev server
"http://localhost:3000",
"http://127.0.0.1:5173",
"http://127.0.0.1:3000",
],
allow_origins=allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand Down
Loading