Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions openviking/resource/watch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _check_permission(
task: WatchTask,
account_id: str,
user_id: str,
agent_id: str,
role: str,
) -> bool:
"""Check if user has permission to access/modify a task.
Expand All @@ -278,6 +279,7 @@ def _check_permission(
task: The task to check permission for
account_id: Requester's account ID
user_id: Requester's user ID
agent_id: Requester's agent ID
role: Requester's role (ROOT/ADMIN/USER)

Returns:
Expand All @@ -286,7 +288,7 @@ def _check_permission(
Notes:
- ROOT can access all tasks.
- ADMIN can access tasks within the same account.
- USER can only access tasks they created within the same account.
- USER can only access tasks they created within the same account and agent.
"""
role_value = (role or "").lower()
if role_value == "root":
Expand All @@ -298,7 +300,7 @@ def _check_permission(
if role_value == "admin":
return True

return task.user_id == user_id
return task.user_id == user_id and task.agent_id == agent_id

def _check_uri_conflict(
self, to_uri: Optional[str], exclude_task_id: Optional[str] = None
Expand Down Expand Up @@ -417,6 +419,7 @@ async def update_task(
summarize: Optional[bool] = None,
processor_kwargs: Optional[Dict[str, Any]] = None,
is_active: Optional[bool] = None,
agent_id: str = "default",
) -> WatchTask:
"""Update an existing monitoring task.

Expand All @@ -425,6 +428,7 @@ async def update_task(
account_id: Requester's account ID
user_id: Requester's user ID
role: Requester's role (ROOT/ADMIN/USER)
agent_id: Requester's agent ID
path: New resource path
to_uri: New target URI
parent_uri: New parent URI
Expand All @@ -446,9 +450,9 @@ async def update_task(
if not task:
raise ValueError(f"Task {task_id} not found")

if not self._check_permission(task, account_id, user_id, role):
if not self._check_permission(task, account_id, user_id, agent_id, role):
raise PermissionDeniedError(
f"User {account_id}/{user_id} does not have permission to update task {task_id}"
f"User {account_id}/{user_id}/{agent_id} does not have permission to update task {task_id}"
)

if self._check_uri_conflict(to_uri, exclude_task_id=task_id):
Expand Down Expand Up @@ -518,6 +522,7 @@ async def delete_task(
account_id: str,
user_id: str,
role: str,
agent_id: str = "default",
) -> bool:
"""Delete a monitoring task.

Expand All @@ -526,6 +531,7 @@ async def delete_task(
account_id: Requester's account ID
user_id: Requester's user ID
role: Requester's role (ROOT/ADMIN/USER)
agent_id: Requester's agent ID

Returns:
True if task was deleted, False if not found
Expand All @@ -538,9 +544,9 @@ async def delete_task(
if not task:
return False

if not self._check_permission(task, account_id, user_id, role):
if not self._check_permission(task, account_id, user_id, agent_id, role):
raise PermissionDeniedError(
f"User {account_id}/{user_id} does not have permission to delete task {task_id}"
f"User {account_id}/{user_id}/{agent_id} does not have permission to delete task {task_id}"
)

self._tasks.pop(task_id, None)
Expand All @@ -558,6 +564,7 @@ async def get_task(
account_id: str = "default",
user_id: str = "default",
role: str = "root",
agent_id: str = "default",
) -> Optional[WatchTask]:
"""Get a monitoring task by ID.

Expand All @@ -566,6 +573,7 @@ async def get_task(
account_id: Requester's account ID
user_id: Requester's user ID
role: Requester's role (ROOT/ADMIN/USER)
agent_id: Requester's agent ID

Returns:
WatchTask if found and accessible, None otherwise
Expand All @@ -575,7 +583,7 @@ async def get_task(
if not task:
return None

if not self._check_permission(task, account_id, user_id, role):
if not self._check_permission(task, account_id, user_id, agent_id, role):
return None

return task
Expand All @@ -586,13 +594,15 @@ async def get_all_tasks(
user_id: str,
role: str,
active_only: bool = False,
agent_id: str = "default",
) -> List[WatchTask]:
"""Get all monitoring tasks accessible by the user.

Args:
account_id: Requester's account ID
user_id: Requester's user ID
role: Requester's role (ROOT/ADMIN/USER)
agent_id: Requester's agent ID
active_only: If True, only return active tasks

Returns:
Expand All @@ -601,7 +611,7 @@ async def get_all_tasks(
async with self._lock:
tasks = []
for task in self._tasks.values():
if not self._check_permission(task, account_id, user_id, role):
if not self._check_permission(task, account_id, user_id, agent_id, role):
continue
if active_only and not task.is_active:
continue
Expand All @@ -614,6 +624,7 @@ async def get_task_by_uri(
account_id: str,
user_id: str,
role: str,
agent_id: str = "default",
) -> Optional[WatchTask]:
"""Get a monitoring task by target URI.

Expand All @@ -622,6 +633,7 @@ async def get_task_by_uri(
account_id: Requester's account ID
user_id: Requester's user ID
role: Requester's role (ROOT/ADMIN/USER)
agent_id: Requester's agent ID

Returns:
WatchTask if found and accessible, None otherwise
Expand All @@ -635,7 +647,7 @@ async def get_task_by_uri(
if not task:
return None

if not self._check_permission(task, account_id, user_id, role):
if not self._check_permission(task, account_id, user_id, agent_id, role):
return None

return task
Expand Down
1 change: 1 addition & 0 deletions openviking/resource/watch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ async def _execute_task(self, task) -> None:
account_id=task.account_id,
user_id=task.user_id,
role=getattr(task, "original_role", None) or Role.USER.value,
agent_id=task.agent_id,
is_active=False,
)
)
Expand Down
4 changes: 4 additions & 0 deletions openviking/service/resource_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ async def _handle_watch_task_creation(
account_id=ctx.account_id,
user_id=ctx.user.user_id,
role=ctx.role.value,
agent_id=ctx.user.agent_id,
)
if existing_task:
if existing_task.is_active:
Expand All @@ -296,6 +297,7 @@ async def _handle_watch_task_creation(
account_id=ctx.account_id,
user_id=ctx.user.user_id,
role=ctx.role.value,
agent_id=ctx.user.agent_id,
path=path,
to_uri=to_uri,
parent_uri=parent_uri,
Expand Down Expand Up @@ -344,13 +346,15 @@ async def _handle_watch_task_cancellation(self, to_uri: str, ctx: RequestContext
account_id=ctx.account_id,
user_id=ctx.user.user_id,
role=ctx.role.value,
agent_id=ctx.user.agent_id,
)
if existing_task:
await watch_manager.update_task(
task_id=existing_task.task_id,
account_id=ctx.account_id,
user_id=ctx.user.user_id,
role=ctx.role.value,
agent_id=ctx.user.agent_id,
is_active=False,
)
logger.info(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_watch_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async def get_watch_task(client: AsyncOpenViking, to_uri: str):
account_id=client._service.user.account_id,
user_id=client._service.user.user_id,
role=Role.USER.value,
agent_id=client._service.user.agent_id,
)


Expand Down
91 changes: 90 additions & 1 deletion tests/resource/test_watch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import pytest
import pytest_asyncio

from openviking.resource.watch_manager import WatchManager, WatchTask
from openviking.resource.watch_manager import PermissionDeniedError, WatchManager, WatchTask
from openviking_cli.exceptions import ConflictError
from tests.utils.mock_agfs import MockLocalAGFS

TEST_ACCOUNT_ID = "default"
TEST_USER_ID = "default"
TEST_AGENT_ID = "default"
OTHER_AGENT_ID = "other-agent"
TEST_ROLE = "ROOT"


Expand Down Expand Up @@ -489,6 +491,93 @@ async def test_get_next_execution_time(self, watch_manager: WatchManager):
next_time = await watch_manager.get_next_execution_time()
assert next_time is not None

@pytest.mark.asyncio
async def test_user_cannot_access_other_agent_task(self, watch_manager: WatchManager):
task = await watch_manager.create_task(
path="/test/path",
to_uri="viking://resources/agent-isolation",
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
agent_id=TEST_AGENT_ID,
)

by_task_id = await watch_manager.get_task(
task.task_id,
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
role="USER",
agent_id=OTHER_AGENT_ID,
)
by_uri = await watch_manager.get_task_by_uri(
to_uri="viking://resources/agent-isolation",
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
role="USER",
agent_id=OTHER_AGENT_ID,
)
tasks = await watch_manager.get_all_tasks(
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
role="USER",
agent_id=OTHER_AGENT_ID,
)

assert by_task_id is None
assert by_uri is None
assert tasks == []

@pytest.mark.asyncio
async def test_user_cannot_update_or_delete_other_agent_task(self, watch_manager: WatchManager):
task = await watch_manager.create_task(
path="/test/path",
to_uri="viking://resources/agent-update-delete",
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
agent_id=TEST_AGENT_ID,
)

with pytest.raises(PermissionDeniedError, match="does not have permission"):
await watch_manager.update_task(
task_id=task.task_id,
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
role="USER",
agent_id=OTHER_AGENT_ID,
reason="other agent should not update",
)

with pytest.raises(PermissionDeniedError, match="does not have permission"):
await watch_manager.delete_task(
task_id=task.task_id,
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
role="USER",
agent_id=OTHER_AGENT_ID,
)

@pytest.mark.asyncio
async def test_admin_can_manage_other_agent_task_in_same_account(
self, watch_manager: WatchManager
):
task = await watch_manager.create_task(
path="/test/path",
to_uri="viking://resources/admin-cross-agent",
account_id=TEST_ACCOUNT_ID,
user_id=TEST_USER_ID,
agent_id=TEST_AGENT_ID,
)

updated = await watch_manager.update_task(
task_id=task.task_id,
account_id=TEST_ACCOUNT_ID,
user_id="admin-user",
role="ADMIN",
agent_id=OTHER_AGENT_ID,
reason="admin update",
)

assert updated.reason == "admin update"


class TestWatchManagerPersistence:
"""Tests for WatchManager persistence."""
Expand Down
Loading
Loading