diff --git a/.gitignore b/.gitignore index d2608d50..a6b5e583 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ share/python-wheels/ *.egg MANIFEST notebooks/ +scripts/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/aixplain/v2/__init__.py b/aixplain/v2/__init__.py index ebc70098..af50b70c 100644 --- a/aixplain/v2/__init__.py +++ b/aixplain/v2/__init__.py @@ -17,6 +17,7 @@ EvaluatorConfig, EditorConfig, ) +from .session import Session, SessionMessage, SessionMessageAttachment from .meta_agents import Debugger, DebugResult from .agent_progress import AgentProgressTracker, ProgressFormat from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit, TokenType @@ -48,6 +49,11 @@ EvolveType, CodeInterpreterModel, SplittingOptions, + SessionStatus, + RunStatus, + MessageRole, + Reaction, + AttachmentType, ) __all__ = [ @@ -59,6 +65,10 @@ "FileUploader", "upload_file", "validate_file_for_upload", + # Session classes + "Session", + "SessionMessage", + "SessionMessageAttachment", # Inspector classes "Inspector", "InspectorTarget", @@ -111,4 +121,9 @@ "EvolveType", "CodeInterpreterModel", "SplittingOptions", + "SessionStatus", + "RunStatus", + "MessageRole", + "Reaction", + "AttachmentType", ] diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index 356a7020..fbb40b4f 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -38,10 +38,14 @@ class ConversationMessage(TypedDict): Attributes: role: The role of the message sender, either 'user' or 'assistant' content: The text content of the message + attachments: Optional pre-built attachment dicts (url, name, type) + files: Optional local file paths to upload and attach """ role: Literal["user", "assistant"] content: str + attachments: NotRequired[Optional[List[Dict[str, Any]]]] + files: NotRequired[Optional[List[Any]]] def validate_history(history: List[Dict[str, Any]]) -> bool: @@ -969,74 +973,72 @@ def build_run_payload(self, **kwargs: Unpack[AgentRunParams]) -> dict: return payload - def generate_session_id(self, history: Optional[List[ConversationMessage]] = None) -> str: - """Generate a unique session ID for agent conversations. - - Creates a unique session identifier based on the agent ID and current timestamp. - If conversation history is provided, it attempts to initialize the session on the - server to enable context-aware conversations. + def create_session( + self, + name: Optional[str] = None, + history: Optional[List[ConversationMessage]] = None, + ) -> "Session": + """Create a new backend-managed session for this agent. Args: - history: Previous conversation history. Each message should contain - 'role' (either 'user' or 'assistant') and 'content' keys. - Defaults to None. + name: Optional human-readable name for the session. + history: Optional conversation history to seed the session with. + Each message must have 'role' and 'content' keys. + Messages may also include optional 'attachments' + (pre-built dicts with url/name/type) and/or 'files' + (local file paths to upload). Returns: - str: A unique session identifier in the format "{agent_id}_{timestamp}". + Session: The created Session instance, pre-populated with + history messages when provided. Raises: - ValueError: If the history format is invalid. + ValueError: If the agent has not been saved yet or if history + format is invalid. Example: - >>> agent = Agent.get("my_agent_id") - >>> session_id = agent.generate_session_id() - >>> # Or with history - >>> history = [ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ] - >>> session_id = agent.generate_session_id(history=history) + >>> session = agent.create_session( + ... name="My Chat", + ... history=[ + ... {"role": "user", "content": "Analyze this", + ... "files": ["/tmp/data.csv"]}, + ... {"role": "assistant", "content": "Here are the results..."}, + ... ], + ... ) """ if not self.id: - self.save(as_draft=True) + raise ValueError("Agent must be saved before creating a session. Call agent.save() first.") if history: validate_history(history) - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - session_id = f"{self.id}_{timestamp}" - - if not history: - return session_id - - try: - # Use the existing run infrastructure to initialize the session - result = self.run_async( - query="/", - session_id=session_id, - history=history, - execution_params={ - "max_tokens": 2048, - "max_iterations": 10, - "output_format": OutputFormat.TEXT.value, - "expected_output": None, - }, - allow_history_and_session_id=True, - ) + session = self.context.Session(agent_id=self.id, name=name) + session.save() + + if history: + for message in history: + session.add_message( + role=message["role"], + content=message["content"], + attachments=message.get("attachments"), + files=message.get("files"), + ) - # If we got a polling URL, poll for completion - if result.url and not result.completed: - final_result = self.sync_poll(result.url, timeout=300, wait_time=0.5) + return session - if final_result.status == ResponseStatus.SUCCESS: - return session_id - else: - logging.error(f"Session {session_id} initialization failed: {final_result}") - return session_id - else: - # Direct completion or no polling needed - return session_id + def list_sessions(self, status: Optional[str] = None) -> list: + """List sessions for this agent. + + Args: + status: Optional status filter (e.g. "active", "completed"). + + Returns: + List of Session instances belonging to this agent. + + Raises: + ValueError: If the agent has not been saved yet. + """ + if not self.id: + raise ValueError("Agent must be saved before listing sessions. Call agent.save() first.") - except Exception as e: - logging.error(f"Failed to initialize session {session_id}: {e}") - return session_id + return self.context.Session.list(agent_id=self.id, status=status) diff --git a/aixplain/v2/core.py b/aixplain/v2/core.py index bb42a295..17df180b 100644 --- a/aixplain/v2/core.py +++ b/aixplain/v2/core.py @@ -13,6 +13,7 @@ from .inspector import Inspector from .meta_agents import Debugger from .api_key import APIKey +from .session import Session from . import enums @@ -25,6 +26,7 @@ InspectorType = TypeVar("InspectorType", bound=Inspector) DebuggerType = TypeVar("DebuggerType", bound=Debugger) APIKeyType = TypeVar("APIKeyType", bound=APIKey) +SessionType = TypeVar("SessionType", bound=Session) class Aixplain: @@ -49,6 +51,7 @@ class Aixplain: Inspector: InspectorType = None Debugger: DebuggerType = None APIKey: APIKeyType = None + Session: SessionType = None Function = enums.Function Supplier = enums.Supplier @@ -66,6 +69,12 @@ class Aixplain: SortOrder = enums.SortOrder StorageType = enums.StorageType + SessionStatus = enums.SessionStatus + RunStatus = enums.RunStatus + MessageRole = enums.MessageRole + Reaction = enums.Reaction + AttachmentType = enums.AttachmentType + BACKEND_URL = "https://platform-api.aixplain.com" BENCHMARKS_BACKEND_URL = "https://platform-api.aixplain.com" MODELS_RUN_URL = "https://models.aixplain.com/api/v2/execute" @@ -121,3 +130,4 @@ def init_resources(self) -> None: self.Inspector = type("Inspector", (Inspector,), {"context": self}) self.Debugger = type("Debugger", (Debugger,), {"context": self}) self.APIKey = type("APIKey", (APIKey,), {"context": self}) + self.Session = type("Session", (Session,), {"context": self}) diff --git a/aixplain/v2/enums.py b/aixplain/v2/enums.py index 7b9f2c81..77b36a8e 100644 --- a/aixplain/v2/enums.py +++ b/aixplain/v2/enums.py @@ -247,6 +247,49 @@ class SplittingOptions(str, Enum): LINE = "line" +class SessionStatus(str, Enum): + """Session status values.""" + + ACTIVE = "active" + COMPLETED = "completed" + FAILED = "failed" + ARCHIVED = "archived" + + +class RunStatus(str, Enum): + """Run status values for sessions.""" + + IDLE = "idle" + RUNNING = "running" + COMPLETED = "completed" + + +class MessageRole(str, Enum): + """Message role in a session conversation.""" + + USER = "user" + ASSISTANT = "assistant" + + +class Reaction(str, Enum): + """Reaction types for session messages.""" + + LIKE = "LIKE" + DISLIKE = "DISLIKE" + + +class AttachmentType(str, Enum): + """Attachment type for session message attachments.""" + + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + DOCUMENT = "document" + CODE = "code" + UNKNOWN = "unknown" + + __all__ = [ "AuthenticationScheme", "FileType", @@ -268,4 +311,9 @@ class SplittingOptions(str, Enum): "CodeInterpreterModel", "DataType", "SplittingOptions", + "SessionStatus", + "RunStatus", + "MessageRole", + "Reaction", + "AttachmentType", ] diff --git a/aixplain/v2/session.py b/aixplain/v2/session.py new file mode 100644 index 00000000..fd621b63 --- /dev/null +++ b/aixplain/v2/session.py @@ -0,0 +1,390 @@ +"""Session module for aiXplain v2 SDK.""" + +import os +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union +from pathlib import Path + +from dataclasses_json import dataclass_json, config + +from .enums import AttachmentType +from .exceptions import APIError, ResourceError +from .resource import ( + BaseResource, + GetResourceMixin, + DeleteResourceMixin, + BaseGetParams, + BaseDeleteParams, +) + +logger = logging.getLogger(__name__) + + +def _mime_to_attachment_type(mime_type: str) -> str: + """Map a MIME type string to an AttachmentType value.""" + if not mime_type: + return AttachmentType.UNKNOWN.value + + main_type = mime_type.split("/")[0] if "/" in mime_type else "" + sub_type = mime_type.split("/")[1] if "/" in mime_type else "" + + # Code types + code_subtypes = { + "x-python", + "javascript", + "x-javascript", + "typescript", + "x-java-source", + "x-c", + "x-c++", + "x-ruby", + "x-go", + "x-rust", + "x-shellscript", + } + if sub_type in code_subtypes: + return AttachmentType.CODE.value + + # Document types + document_subtypes = { + "pdf", + "msword", + "vnd.openxmlformats-officedocument.wordprocessingml.document", + "vnd.ms-excel", + "vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "vnd.ms-powerpoint", + "vnd.openxmlformats-officedocument.presentationml.presentation", + } + if sub_type in document_subtypes: + return AttachmentType.DOCUMENT.value + + # Main type mapping + type_map = { + "image": AttachmentType.IMAGE.value, + "video": AttachmentType.VIDEO.value, + "audio": AttachmentType.AUDIO.value, + "text": AttachmentType.TEXT.value, + } + if main_type in type_map: + return type_map[main_type] + + return AttachmentType.UNKNOWN.value + + +def _parse_list_response(response: Any, item_type: str) -> list: + """Parse a response that should be a bare JSON array. + + Args: + response: The API response (should be a list). + item_type: Description of items for error messages. + + Returns: + The response as a list. + + Raises: + ResourceError: If response is not a list. + """ + if isinstance(response, list): + return response + raise ResourceError( + f"Expected a list of {item_type} from the API, got {type(response).__name__}: {str(response)[:200]}" + ) + + +def _deserialize(cls, data: dict, description: str): + """Safely deserialize a dict into a dataclass. + + Args: + cls: The dataclass type with from_dict. + data: The dict to deserialize. + description: Description for error messages. + + Returns: + The deserialized instance. + + Raises: + ResourceError: If deserialization fails. + """ + try: + return cls.from_dict(data) + except Exception as e: + raise ResourceError(f"Failed to parse {description}: {e}. Response data: {str(data)[:200]}") + + +@dataclass_json +@dataclass +class SessionMessageAttachment: + """Attachment on a session message.""" + + url: str + name: Optional[str] = None + type: Optional[str] = None + + +@dataclass_json +@dataclass +class SessionMessage: + """A message within a session (not a resource — all ops go through Session).""" + + id: str = "" + session_id: str = field(default="", metadata=config(field_name="sessionId")) + user_id: str = field(default="", metadata=config(field_name="userId")) + agent_id: str = field(default="", metadata=config(field_name="agentId")) + role: str = "" + content: str = "" + sequence: int = 0 + request_id: Optional[str] = field(default=None, metadata=config(field_name="requestId")) + reaction: Optional[str] = None + attachments: Optional[List[SessionMessageAttachment]] = None + created_at: str = field(default="", metadata=config(field_name="createdAt")) + + +@dataclass_json +@dataclass(repr=False) +class Session( + BaseResource, + GetResourceMixin[BaseGetParams, "Session"], + DeleteResourceMixin[BaseDeleteParams, "Session"], +): + """Session resource for managing agent conversation sessions.""" + + RESOURCE_PATH = "v1/sessions" + + user_id: str = field(default="", metadata=config(field_name="userId")) + agent_id: str = field(default="", metadata=config(field_name="agentId")) + name: Optional[str] = None + status: str = "active" + run_status: str = field(default="", metadata=config(field_name="runStatus")) + message_count: int = field(default=0, metadata=config(field_name="messageCount")) + last_message_preview: Optional[str] = field(default=None, metadata=config(field_name="lastMessagePreview")) + last_message_at: Optional[str] = field(default=None, metadata=config(field_name="lastMessageAt")) + created_at: str = field(default="", metadata=config(field_name="createdAt")) + updated_at: str = field(default="", metadata=config(field_name="updatedAt")) + + def build_save_payload(self, **kwargs: Any) -> dict: + """Build payload with only mutable fields.""" + payload = {} + if self.agent_id: + payload["agentId"] = self.agent_id + if self.name is not None: + payload["name"] = self.name + if self.status: + payload["status"] = self.status + return payload + + @classmethod + def list( + cls, + agent_id: Optional[str] = None, + status: Optional[str] = None, + user_id: Optional[str] = None, + ) -> List["Session"]: + """List sessions with optional filters. + + Args: + agent_id: Filter by agent ID. + status: Filter by session status. + user_id: Filter by user ID. + + Returns: + List of Session instances. + + Raises: + ResourceError: If the API response is not a list or + deserialization fails. + APIError: If the API request fails. + """ + context = getattr(cls, "context", None) + if context is None: + raise ResourceError("Context is required for resource operations") + + params: Dict[str, str] = {} + if agent_id is not None: + params["agentId"] = agent_id + if status is not None: + params["status"] = status + if user_id is not None: + params["userId"] = user_id + + try: + response = context.client.request("get", cls.RESOURCE_PATH, params=params) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to list sessions: {e}") + + items = _parse_list_response(response, "sessions") + sessions = [] + for item in items: + session = _deserialize(cls, item, "session") + session.context = context + session._update_saved_state() + sessions.append(session) + return sessions + + def messages(self) -> List[SessionMessage]: + """Get all messages in this session. + + Returns: + List of SessionMessage instances. + + Raises: + ResourceError: If the API response is not a list or + deserialization fails. + APIError: If the API request fails. + """ + self._ensure_valid_state() + path = f"{self.RESOURCE_PATH}/{self.encoded_id}/messages" + + try: + response = self.context.client.request("get", path) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to list messages for session '{self.id}': {e}") + + items = _parse_list_response(response, "messages") + return [_deserialize(SessionMessage, item, "session message") for item in items] + + def add_message( + self, + role: str, + content: str, + request_id: Optional[str] = None, + attachments: Optional[List[Dict[str, Any]]] = None, + files: Optional[List[Union[str, Path]]] = None, + ) -> SessionMessage: + """Add a message to this session. + + Args: + role: Message role ("user" or "assistant"). + content: Message content. + request_id: Optional request ID to associate with the message. + attachments: Pre-built attachment dicts with url, name, type keys. + files: Local file paths to upload and attach. + + Returns: + The created SessionMessage. + + Raises: + ResourceError: If the operation fails. + APIError: If the API request fails. + FileUploadError: If a file upload fails. + """ + self._ensure_valid_state() + + all_attachments = list(attachments) if attachments else [] + + if files: + from .upload_utils import FileUploader, MimeTypeDetector + + uploader = FileUploader( + backend_url=self.context.backend_url, + api_key=self.context.api_key, + ) + for file_path in files: + file_path_str = str(file_path) + try: + download_url = uploader.upload( + file_path_str, + is_temp=True, + return_download_link=True, + ) + except Exception as e: + raise ResourceError(f"Failed to upload file '{file_path_str}' for session '{self.id}': {e}") + mime_type = MimeTypeDetector.detect_mime_type(file_path_str) + att_type = _mime_to_attachment_type(mime_type) + all_attachments.append( + { + "url": download_url, + "name": os.path.basename(file_path_str), + "type": att_type, + } + ) + + payload: Dict[str, Any] = {"role": role, "content": content} + if request_id is not None: + payload["requestId"] = request_id + if all_attachments: + payload["attachments"] = all_attachments + + path = f"{self.RESOURCE_PATH}/{self.encoded_id}/messages" + try: + response = self.context.client.request("post", path, json=payload) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to add message to session '{self.id}': {e}") + + return _deserialize(SessionMessage, response, "session message") + + def get_message(self, message_id: str) -> SessionMessage: + """Get a specific message by ID. + + Args: + message_id: The message ID. + + Returns: + The SessionMessage. + + Raises: + ResourceError: If deserialization fails. + APIError: If the API request fails (e.g., message not found). + """ + self._ensure_valid_state() + path = f"{self.RESOURCE_PATH}/{self.encoded_id}/messages/{message_id}" + try: + response = self.context.client.request("get", path) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to get message '{message_id}' from session '{self.id}': {e}") + return _deserialize(SessionMessage, response, "session message") + + def delete_message(self, message_id: str) -> None: + """Delete a message from this session. + + Args: + message_id: The message ID to delete. + + Raises: + APIError: If the API request fails (e.g., message not found). + ResourceError: If the session is in an invalid state. + """ + self._ensure_valid_state() + path = f"{self.RESOURCE_PATH}/{self.encoded_id}/messages/{message_id}" + try: + self.context.client.request_raw("delete", path) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to delete message '{message_id}' from session '{self.id}': {e}") + + def react(self, message_id: str, reaction: Optional[str]) -> SessionMessage: + """React to a message or clear a reaction. + + Only assistant messages can be reacted to. + + Args: + message_id: The message ID to react to. + reaction: "LIKE", "DISLIKE", or None to clear. + + Returns: + The updated SessionMessage. + + Raises: + APIError: If the API request fails (e.g., reacting to a + non-assistant message). + ResourceError: If deserialization fails. + """ + self._ensure_valid_state() + path = f"{self.RESOURCE_PATH}/{self.encoded_id}/messages/{message_id}/reaction" + payload: Dict[str, Any] = {"reaction": reaction} + try: + response = self.context.client.request("post", path, json=payload) + except APIError: + raise + except Exception as e: + raise ResourceError(f"Failed to react to message '{message_id}' in session '{self.id}': {e}") + return _deserialize(SessionMessage, response, "session message") diff --git a/tests/functional/v2/test_session.py b/tests/functional/v2/test_session.py new file mode 100644 index 00000000..6e8ba51b --- /dev/null +++ b/tests/functional/v2/test_session.py @@ -0,0 +1,347 @@ +"""Functional tests for v2 Session management. + +These tests run against a real backend and require valid credentials. +Set TEAM_API_KEY (or AIXPLAIN_API_KEY) in the environment. + +A test agent is created once per module and cleaned up afterwards. +""" + +import os +import tempfile +import time + +import pytest + +from aixplain.v2 import Session, SessionMessage, SessionStatus +from aixplain.v2.exceptions import APIError + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def test_agent(client): + """Create a temporary agent for session tests.""" + agent = client.Agent( + name=f"Session Test Agent {int(time.time())}", + description="Temporary agent for session functional tests", + instructions="You are a helpful test agent. Keep responses short.", + ) + agent.save() + yield agent + try: + agent.delete() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Session CRUD +# --------------------------------------------------------------------------- + + +class TestSessionCRUD: + """End-to-end session create / get / list / update / delete.""" + + def test_create_session_via_agent(self, client, test_agent): + """Creating a session through an agent should return a saved Session.""" + session = test_agent.create_session(name="Func Test Session") + + assert session.id is not None + assert isinstance(session, Session) + assert session.agent_id == test_agent.id + assert session.name == "Func Test Session" + assert session.status == "active" + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_create_session_directly(self, client, test_agent): + """Creating a Session instance directly and calling .save().""" + session = client.Session(agent_id=test_agent.id, name="Direct Create") + session.save() + + assert session.id is not None + assert session.agent_id == test_agent.id + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_get_session(self, client, test_agent): + """Retrieving a session by ID should return the same session.""" + session = test_agent.create_session(name="Get Test") + fetched = client.Session.get(session.id) + + assert fetched.id == session.id + assert fetched.name == session.name + assert fetched.agent_id == session.agent_id + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_list_sessions_for_agent(self, client, test_agent): + """Listing sessions for an agent should include the created session.""" + session = test_agent.create_session(name="List Test") + + sessions = test_agent.list_sessions() + + assert isinstance(sessions, list) + session_ids = [s.id for s in sessions] + assert session.id in session_ids + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_list_sessions_with_status_filter(self, client, test_agent): + """Filtering by status should only return matching sessions.""" + session = test_agent.create_session(name="Status Filter Test") + + active = test_agent.list_sessions(status="active") + assert all(s.status == "active" for s in active) + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_update_session(self, client, test_agent): + """Updating session name via save() should persist the change.""" + session = test_agent.create_session(name="Before Update") + session.name = "After Update" + session.save() + + fetched = client.Session.get(session.id) + assert fetched.name == "After Update" + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_create_session_with_history(self, client, test_agent): + """Creating a session with history should seed it with messages.""" + history = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + session = test_agent.create_session(name="History Test", history=history) + + assert session.id is not None + messages = session.messages() + assert len(messages) >= 2 + contents = [m.content for m in messages] + assert "What is 2+2?" in contents + assert "4" in contents + + # Cleanup + try: + session.delete() + except Exception: + pass + + def test_delete_session(self, client, test_agent): + """Deleting a session should succeed.""" + session = test_agent.create_session(name="Delete Me") + session_id = session.id + assert session_id is not None + + result = session.delete() + assert result.completed is True + + +# --------------------------------------------------------------------------- +# Session messages +# --------------------------------------------------------------------------- + + +class TestSessionMessages: + """End-to-end tests for session message operations.""" + + @pytest.fixture() + def session(self, client, test_agent): + """Create a session for message tests and clean up after.""" + s = test_agent.create_session(name=f"Msg Test {int(time.time())}") + yield s + try: + s.delete() + except Exception: + pass + + def test_add_and_get_message(self, session): + """Adding a message should return a SessionMessage with content.""" + msg = session.add_message(role="user", content="Hello from test!") + + assert isinstance(msg, SessionMessage) + assert msg.id is not None + assert msg.content == "Hello from test!" + assert msg.role == "user" + + def test_list_messages(self, session): + """After adding messages, messages() should return them.""" + session.add_message(role="user", content="First message") + session.add_message(role="assistant", content="Second message") + + messages = session.messages() + + assert isinstance(messages, list) + assert len(messages) >= 2 + contents = [m.content for m in messages] + assert "First message" in contents + assert "Second message" in contents + + def test_get_single_message(self, session): + """get_message() should return the specific message.""" + created = session.add_message(role="user", content="Specific message") + + fetched = session.get_message(created.id) + + assert fetched.id == created.id + assert fetched.content == "Specific message" + + def test_delete_message(self, session): + """delete_message() should remove the message.""" + msg = session.add_message(role="user", content="Delete this message") + session.delete_message(msg.id) + + remaining = session.messages() + remaining_ids = [m.id for m in remaining] + assert msg.id not in remaining_ids + + def test_add_message_with_url_attachments(self, session): + """Adding a message with explicit URL attachments should include them.""" + attachments = [ + {"url": "https://example.com/test.png", "name": "test.png", "type": "image"}, + ] + msg = session.add_message( + role="user", + content="See attached", + attachments=attachments, + ) + + assert msg.id is not None + if msg.attachments: + assert msg.attachments[0].url == "https://example.com/test.png" + assert msg.attachments[0].name == "test.png" + assert msg.attachments[0].type == "image" + + def test_add_message_with_file_upload(self, session): + """Uploading a local file via add_message(files=...) should attach it.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", prefix="session_test_", delete=False) as f: + f.write("Test content for file upload verification.\n") + tmp_path = f.name + + try: + msg = session.add_message( + role="user", + content="Please review this file", + files=[tmp_path], + ) + + assert msg.id is not None + assert msg.attachments is not None, "Backend should echo attachments" + assert len(msg.attachments) == 1 + + att = msg.attachments[0] + assert att.url is not None and att.url.startswith("http") + assert att.name == os.path.basename(tmp_path) + assert att.type == "text" + + # Verify the attachment persists when re-fetching + fetched = session.get_message(msg.id) + assert fetched.attachments is not None + assert len(fetched.attachments) == 1 + finally: + os.remove(tmp_path) + + def test_react_like_and_dislike(self, session): + """Reacting to an assistant message with LIKE then DISLIKE should work.""" + session.add_message(role="user", content="Say something") + msg = session.add_message(role="assistant", content="Here is my response") + + liked = session.react(msg.id, "LIKE") + assert liked.reaction == "LIKE" + + disliked = session.react(msg.id, "DISLIKE") + assert disliked.reaction == "DISLIKE" + + def test_clear_reaction(self, session): + """Passing None to react() should clear the reaction.""" + session.add_message(role="user", content="Say something") + msg = session.add_message(role="assistant", content="Clear reaction response") + session.react(msg.id, "DISLIKE") + + cleared = session.react(msg.id, None) + assert cleared.reaction is None + + +# --------------------------------------------------------------------------- +# Agent run + session interaction +# --------------------------------------------------------------------------- + + +class TestSessionWithAgentRun: + """Tests for how agent.run() interacts with sessions.""" + + def test_run_with_session_adds_messages(self, client, test_agent): + """Running an agent with a session_id should add messages to the session.""" + session = test_agent.create_session(name="Run Test") + msgs_before = session.messages() + + result = test_agent.run( + "What is the capital of France?", + session_id=session.id, + ) + assert result.status == "SUCCESS" + + msgs_after = session.messages() + assert len(msgs_after) > len(msgs_before), "Backend should auto-add messages to the session during a run" + + # Cleanup + try: + session.delete() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestSessionErrors: + """Tests that backend errors are raised gracefully.""" + + def test_get_nonexistent_session(self, client): + """Getting a session that doesn't exist should raise APIError.""" + with pytest.raises(APIError): + client.Session.get("nonexistent-session-id-12345") + + def test_react_to_user_message_raises_error(self, client, test_agent): + """Reacting to a user message should raise an APIError.""" + session = test_agent.create_session(name="React Error Test") + msg = session.add_message(role="user", content="Can't like this") + + with pytest.raises(APIError, match="assistant"): + session.react(msg.id, "LIKE") + + # Cleanup + try: + session.delete() + except Exception: + pass diff --git a/tests/functional/v2/test_snake_case_e2e.py b/tests/functional/v2/test_snake_case_e2e.py index 12d2ddf2..778bb442 100644 --- a/tests/functional/v2/test_snake_case_e2e.py +++ b/tests/functional/v2/test_snake_case_e2e.py @@ -175,13 +175,20 @@ class TestAgentRunParamsKwargs: def test_session_id(self, client, test_agent): """session_id (renamed from sessionId) is sent to the backend and reflected in the response.""" agent = client.Agent.get(test_agent.id) - sid = agent.generate_session_id() + session = agent.create_session(name="snake_case_test") + sid = session.id - response = agent.run("ping", session_id=sid) + try: + response = agent.run("ping", session_id=sid) - assert response.status == "SUCCESS" - assert response.data.session_id is not None - assert sid in response.data.session_id + assert response.status == "SUCCESS" + assert response.data.session_id is not None + assert sid in response.data.session_id + finally: + try: + session.delete() + except Exception: + pass def test_execution_params(self, client, test_agent): """execution_params (renamed from executionParams) reaches the backend. @@ -210,13 +217,20 @@ def test_run_response_generation(self, client, test_agent): def test_allow_history_and_session_id(self, client, test_agent): """allow_history_and_session_id (renamed from allowHistoryAndSessionId) is accepted.""" agent = client.Agent.get(test_agent.id) - sid = agent.generate_session_id() + session = agent.create_session(name="history_test") + sid = session.id - response = agent.run( - "ping", - session_id=sid, - history=[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}], - allow_history_and_session_id=True, - ) + try: + response = agent.run( + "ping", + session_id=sid, + history=[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}], + allow_history_and_session_id=True, + ) - assert response.status == "SUCCESS" + assert response.status == "SUCCESS" + finally: + try: + session.delete() + except Exception: + pass diff --git a/tests/unit/v2/test_session.py b/tests/unit/v2/test_session.py new file mode 100644 index 00000000..d562f4b8 --- /dev/null +++ b/tests/unit/v2/test_session.py @@ -0,0 +1,1010 @@ +"""Unit tests for the v2 Session module. + +This module tests Session, SessionMessage, and SessionMessageAttachment +dataclasses, serialization, CRUD operations, and the MIME-to-AttachmentType +mapping — all with mocked HTTP calls. +""" + +from unittest.mock import Mock, patch, call + +import pytest + +from aixplain.v2.session import ( + Session, + SessionMessage, + SessionMessageAttachment, + _mime_to_attachment_type, +) +from aixplain.v2.enums import ( + SessionStatus, + RunStatus, + MessageRole, + Reaction, + AttachmentType, +) +from aixplain.v2.agent import Agent +from aixplain.v2.exceptions import ValidationError, APIError, ResourceError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_context(**overrides): + """Create a mock Aixplain context with a mock client.""" + client = Mock() + ctx = Mock( + client=client, + backend_url="https://platform-api.aixplain.com", + api_key="test_key", + ) + for k, v in overrides.items(): + setattr(ctx, k, v) + return ctx + + +def _bound_session_class(ctx=None): + """Return a Session subclass bound to a mock context.""" + ctx = ctx or _make_mock_context() + + class BoundSession(Session): + context = ctx + + return BoundSession + + +def _bound_agent_class(ctx=None): + """Return an Agent subclass bound to a mock context.""" + ctx = ctx or _make_mock_context() + + class BoundAgent(Agent): + context = ctx + + return BoundAgent + + +# Sample API payloads --------------------------------------------------- + +SAMPLE_SESSION_DICT = { + "id": "sess_001", + "userId": "user_42", + "agentId": "agent_99", + "name": "My Session", + "status": "active", + "runStatus": "idle", + "messageCount": 3, + "lastMessagePreview": "Hello there", + "lastMessageAt": "2025-06-01T12:00:00Z", + "createdAt": "2025-06-01T10:00:00Z", + "updatedAt": "2025-06-01T12:00:00Z", +} + +SAMPLE_MESSAGE_DICT = { + "id": "msg_001", + "sessionId": "sess_001", + "userId": "user_42", + "agentId": "agent_99", + "role": "user", + "content": "Hello!", + "sequence": 1, + "requestId": "req_001", + "reaction": None, + "attachments": None, + "createdAt": "2025-06-01T10:01:00Z", +} + + +# ========================================================================= +# Enum tests +# ========================================================================= + + +class TestSessionEnums: + """Verify new enums have expected members and values.""" + + def test_session_status_values(self): + assert SessionStatus.ACTIVE == "active" + assert SessionStatus.COMPLETED == "completed" + assert SessionStatus.FAILED == "failed" + assert SessionStatus.ARCHIVED == "archived" + + def test_run_status_values(self): + assert RunStatus.IDLE == "idle" + assert RunStatus.RUNNING == "running" + assert RunStatus.COMPLETED == "completed" + + def test_message_role_values(self): + assert MessageRole.USER == "user" + assert MessageRole.ASSISTANT == "assistant" + + def test_reaction_values(self): + assert Reaction.LIKE == "LIKE" + assert Reaction.DISLIKE == "DISLIKE" + + def test_attachment_type_values(self): + assert AttachmentType.TEXT == "text" + assert AttachmentType.IMAGE == "image" + assert AttachmentType.VIDEO == "video" + assert AttachmentType.AUDIO == "audio" + assert AttachmentType.DOCUMENT == "document" + assert AttachmentType.CODE == "code" + assert AttachmentType.UNKNOWN == "unknown" + + +# ========================================================================= +# MIME type mapping +# ========================================================================= + + +class TestMimeToAttachmentType: + """Tests for the _mime_to_attachment_type helper.""" + + @pytest.mark.parametrize( + "mime, expected", + [ + ("image/png", "image"), + ("image/jpeg", "image"), + ("video/mp4", "video"), + ("audio/mpeg", "audio"), + ("audio/wav", "audio"), + ("text/plain", "text"), + ("text/csv", "text"), + ], + ) + def test_main_type_mapping(self, mime, expected): + assert _mime_to_attachment_type(mime) == expected + + @pytest.mark.parametrize( + "mime", + [ + "text/x-python", + "application/javascript", + "application/x-javascript", + "application/typescript", + ], + ) + def test_code_types(self, mime): + assert _mime_to_attachment_type(mime) == "code" + + @pytest.mark.parametrize( + "mime", + [ + "application/pdf", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ], + ) + def test_document_types(self, mime): + assert _mime_to_attachment_type(mime) == "document" + + def test_unknown_fallback(self): + assert _mime_to_attachment_type("application/octet-stream") == "unknown" + + def test_empty_string(self): + assert _mime_to_attachment_type("") == "unknown" + + def test_none_like_empty(self): + # None is not str, but the function guards against falsy values + assert _mime_to_attachment_type("") == "unknown" + + +# ========================================================================= +# SessionMessageAttachment +# ========================================================================= + + +class TestSessionMessageAttachment: + """Tests for the SessionMessageAttachment dataclass.""" + + def test_create_minimal(self): + att = SessionMessageAttachment(url="https://example.com/f.png") + assert att.url == "https://example.com/f.png" + assert att.name is None + assert att.type is None + + def test_create_full(self): + att = SessionMessageAttachment(url="https://example.com/f.png", name="f.png", type="image") + assert att.name == "f.png" + assert att.type == "image" + + def test_serialization_roundtrip(self): + att = SessionMessageAttachment(url="https://example.com/f.png", name="f.png", type="image") + d = att.to_dict() + assert d["url"] == "https://example.com/f.png" + restored = SessionMessageAttachment.from_dict(d) + assert restored.url == att.url + assert restored.name == att.name + assert restored.type == att.type + + +# ========================================================================= +# SessionMessage +# ========================================================================= + + +class TestSessionMessage: + """Tests for the SessionMessage dataclass.""" + + def test_from_dict_maps_camel_case(self): + msg = SessionMessage.from_dict(SAMPLE_MESSAGE_DICT) + assert msg.id == "msg_001" + assert msg.session_id == "sess_001" + assert msg.user_id == "user_42" + assert msg.agent_id == "agent_99" + assert msg.role == "user" + assert msg.content == "Hello!" + assert msg.sequence == 1 + assert msg.request_id == "req_001" + assert msg.reaction is None + assert msg.attachments is None + assert msg.created_at == "2025-06-01T10:01:00Z" + + def test_to_dict_uses_camel_case(self): + msg = SessionMessage( + id="m1", + session_id="s1", + user_id="u1", + agent_id="a1", + role="assistant", + content="Hi", + sequence=2, + created_at="2025-01-01", + ) + d = msg.to_dict() + assert d["sessionId"] == "s1" + assert d["userId"] == "u1" + assert d["agentId"] == "a1" + assert d["requestId"] is None + assert d["createdAt"] == "2025-01-01" + + def test_with_attachments(self): + data = { + **SAMPLE_MESSAGE_DICT, + "attachments": [{"url": "https://example.com/img.png", "name": "img.png", "type": "image"}], + } + msg = SessionMessage.from_dict(data) + assert len(msg.attachments) == 1 + assert isinstance(msg.attachments[0], SessionMessageAttachment) + assert msg.attachments[0].url == "https://example.com/img.png" + + +# ========================================================================= +# Session — dataclass / serialization +# ========================================================================= + + +class TestSessionSerialization: + """Tests for Session dataclass creation and serialization.""" + + def test_from_dict_maps_camel_case(self): + session = Session.from_dict(SAMPLE_SESSION_DICT) + assert session.id == "sess_001" + assert session.user_id == "user_42" + assert session.agent_id == "agent_99" + assert session.name == "My Session" + assert session.status == "active" + assert session.run_status == "idle" + assert session.message_count == 3 + assert session.last_message_preview == "Hello there" + assert session.last_message_at == "2025-06-01T12:00:00Z" + assert session.created_at == "2025-06-01T10:00:00Z" + assert session.updated_at == "2025-06-01T12:00:00Z" + + def test_build_save_payload_create(self): + session = Session(agent_id="agent_99", name="New") + payload = session.build_save_payload() + assert payload == {"agentId": "agent_99", "name": "New", "status": "active"} + + def test_build_save_payload_excludes_readonly(self): + session = Session.from_dict(SAMPLE_SESSION_DICT) + payload = session.build_save_payload() + # Should NOT include read-only fields + assert "userId" not in payload + assert "runStatus" not in payload + assert "messageCount" not in payload + assert "createdAt" not in payload + assert "updatedAt" not in payload + assert "lastMessagePreview" not in payload + assert "lastMessageAt" not in payload + + def test_build_save_payload_update(self): + session = Session.from_dict(SAMPLE_SESSION_DICT) + session.name = "Renamed" + session.status = "archived" + payload = session.build_save_payload() + assert payload["name"] == "Renamed" + assert payload["status"] == "archived" + assert payload["agentId"] == "agent_99" + + def test_default_status_is_active(self): + session = Session(agent_id="a1") + assert session.status == "active" + + def test_repr(self): + session = Session(id="sess_001", name="My Session", agent_id="a1") + r = repr(session) + assert "Session" in r + assert "sess_001" in r + + +# ========================================================================= +# Session — CRUD operations (mocked) +# ========================================================================= + + +class TestSessionCreate: + """Tests for Session.save() (create path).""" + + def test_save_creates_new_session(self): + ctx = _make_mock_context() + ctx.client.request.return_value = { + **SAMPLE_SESSION_DICT, + "id": "new_sess", + } + BoundSession = _bound_session_class(ctx) + session = BoundSession(agent_id="agent_99", name="Test") + + session.save() + + ctx.client.request.assert_called_once() + args, kwargs = ctx.client.request.call_args + assert args[0] == "post" + assert "v1/sessions" in args[1] + assert kwargs["json"]["agentId"] == "agent_99" + assert session.id == "new_sess" + + def test_save_updates_existing_session(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = [ + # First call: create + {**SAMPLE_SESSION_DICT, "id": "sess_001"}, + # Second call: update + None, + ] + BoundSession = _bound_session_class(ctx) + session = BoundSession(agent_id="agent_99", name="Test") + session.save() + assert session.id == "sess_001" + + # Now update + session.name = "Renamed" + session.save() + + second_call = ctx.client.request.call_args_list[1] + assert second_call[0][0] == "put" + assert "sess_001" in second_call[0][1] + + +class TestSessionGet: + """Tests for Session.get().""" + + def test_get_by_id(self): + ctx = _make_mock_context() + ctx.client.get.return_value = SAMPLE_SESSION_DICT + BoundSession = _bound_session_class(ctx) + + session = BoundSession.get("sess_001") + + ctx.client.get.assert_called_once() + call_path = ctx.client.get.call_args[0][0] + assert "v1/sessions/sess_001" in call_path + assert session.id == "sess_001" + assert session.agent_id == "agent_99" + + +class TestSessionDelete: + """Tests for Session.delete().""" + + def test_delete_session(self): + ctx = _make_mock_context() + ctx.client.request_raw.return_value = Mock(status_code=200, ok=True) + BoundSession = _bound_session_class(ctx) + session = BoundSession.from_dict(SAMPLE_SESSION_DICT) + session.context = ctx + session._update_saved_state() + + result = session.delete() + + ctx.client.request_raw.assert_called_once() + args = ctx.client.request_raw.call_args[0] + assert args[0] == "delete" + assert "sess_001" in args[1] + assert result.completed is True + + +class TestSessionList: + """Tests for Session.list() classmethod.""" + + def test_list_returns_sessions(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [ + SAMPLE_SESSION_DICT, + {**SAMPLE_SESSION_DICT, "id": "sess_002", "name": "Second"}, + ] + BoundSession = _bound_session_class(ctx) + + sessions = BoundSession.list() + + ctx.client.request.assert_called_once() + args, kwargs = ctx.client.request.call_args + assert args[0] == "get" + assert "v1/sessions" in args[1] + assert len(sessions) == 2 + assert sessions[0].id == "sess_001" + assert sessions[1].id == "sess_002" + + def test_list_with_agent_id_filter(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [SAMPLE_SESSION_DICT] + BoundSession = _bound_session_class(ctx) + + BoundSession.list(agent_id="agent_99") + + _, kwargs = ctx.client.request.call_args + assert kwargs["params"]["agentId"] == "agent_99" + + def test_list_with_status_filter(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [] + BoundSession = _bound_session_class(ctx) + + BoundSession.list(status="completed") + + _, kwargs = ctx.client.request.call_args + assert kwargs["params"]["status"] == "completed" + + def test_list_with_user_id_filter(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [] + BoundSession = _bound_session_class(ctx) + + BoundSession.list(user_id="user_42") + + _, kwargs = ctx.client.request.call_args + assert kwargs["params"]["userId"] == "user_42" + + def test_list_with_all_filters(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [] + BoundSession = _bound_session_class(ctx) + + BoundSession.list(agent_id="a1", status="active", user_id="u1") + + _, kwargs = ctx.client.request.call_args + params = kwargs["params"] + assert params == {"agentId": "a1", "status": "active", "userId": "u1"} + + def test_list_empty_result(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [] + BoundSession = _bound_session_class(ctx) + + sessions = BoundSession.list() + assert sessions == [] + + +# ========================================================================= +# Session — message operations +# ========================================================================= + + +class TestSessionMessages: + """Tests for Session message methods.""" + + def _make_session(self, ctx=None): + ctx = ctx or _make_mock_context() + BoundSession = _bound_session_class(ctx) + session = BoundSession.from_dict(SAMPLE_SESSION_DICT) + session.context = ctx + session._update_saved_state() + return session + + def test_messages_returns_list(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [SAMPLE_MESSAGE_DICT] + session = self._make_session(ctx) + + messages = session.messages() + + ctx.client.request.assert_called_once() + args = ctx.client.request.call_args[0] + assert args[0] == "get" + assert "sess_001/messages" in args[1] + assert len(messages) == 1 + assert isinstance(messages[0], SessionMessage) + assert messages[0].id == "msg_001" + + def test_messages_empty(self): + ctx = _make_mock_context() + ctx.client.request.return_value = [] + session = self._make_session(ctx) + + messages = session.messages() + assert messages == [] + + def test_add_message_basic(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + msg = session.add_message(role="user", content="Hello!") + + ctx.client.request.assert_called_once() + args, kwargs = ctx.client.request.call_args + assert args[0] == "post" + assert "sess_001/messages" in args[1] + assert kwargs["json"]["role"] == "user" + assert kwargs["json"]["content"] == "Hello!" + assert "requestId" not in kwargs["json"] + assert "attachments" not in kwargs["json"] + assert isinstance(msg, SessionMessage) + + def test_add_message_with_request_id(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + session.add_message(role="user", content="Hi", request_id="req_xyz") + + payload = ctx.client.request.call_args[1]["json"] + assert payload["requestId"] == "req_xyz" + + def test_add_message_with_attachments(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + attachments = [{"url": "https://example.com/f.png", "name": "f.png", "type": "image"}] + session.add_message(role="user", content="See image", attachments=attachments) + + payload = ctx.client.request.call_args[1]["json"] + assert len(payload["attachments"]) == 1 + assert payload["attachments"][0]["url"] == "https://example.com/f.png" + + def test_add_message_with_files(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + with ( + patch("aixplain.v2.upload_utils.FileUploader") as MockUploader, + patch("aixplain.v2.upload_utils.MimeTypeDetector") as MockDetector, + ): + instance = MockUploader.return_value + instance.upload.return_value = "https://cdn.example.com/uploaded.png" + MockDetector.detect_mime_type.return_value = "image/png" + + session.add_message(role="user", content="File", files=["/tmp/photo.png"]) + + instance.upload.assert_called_once_with( + "/tmp/photo.png", + is_temp=True, + return_download_link=True, + ) + payload = ctx.client.request.call_args[1]["json"] + assert len(payload["attachments"]) == 1 + assert payload["attachments"][0]["url"] == "https://cdn.example.com/uploaded.png" + assert payload["attachments"][0]["name"] == "photo.png" + assert payload["attachments"][0]["type"] == "image" + + def test_add_message_merges_attachments_and_files(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + with ( + patch("aixplain.v2.upload_utils.FileUploader") as MockUploader, + patch("aixplain.v2.upload_utils.MimeTypeDetector") as MockDetector, + ): + instance = MockUploader.return_value + instance.upload.return_value = "https://cdn.example.com/doc.pdf" + MockDetector.detect_mime_type.return_value = "application/pdf" + + existing = [{"url": "https://example.com/a.txt", "name": "a.txt", "type": "text"}] + session.add_message( + role="user", + content="Both", + attachments=existing, + files=["/tmp/doc.pdf"], + ) + + payload = ctx.client.request.call_args[1]["json"] + assert len(payload["attachments"]) == 2 + assert payload["attachments"][0]["url"] == "https://example.com/a.txt" + assert payload["attachments"][1]["url"] == "https://cdn.example.com/doc.pdf" + assert payload["attachments"][1]["type"] == "document" + + def test_get_message(self): + ctx = _make_mock_context() + ctx.client.request.return_value = SAMPLE_MESSAGE_DICT + session = self._make_session(ctx) + + msg = session.get_message("msg_001") + + args = ctx.client.request.call_args[0] + assert args[0] == "get" + assert "sess_001/messages/msg_001" in args[1] + assert isinstance(msg, SessionMessage) + assert msg.id == "msg_001" + + def test_delete_message(self): + ctx = _make_mock_context() + ctx.client.request_raw.return_value = Mock(status_code=200, ok=True) + session = self._make_session(ctx) + + session.delete_message("msg_001") + + args = ctx.client.request_raw.call_args[0] + assert args[0] == "delete" + assert "sess_001/messages/msg_001" in args[1] + + def test_react_like(self): + ctx = _make_mock_context() + ctx.client.request.return_value = {**SAMPLE_MESSAGE_DICT, "reaction": "LIKE"} + session = self._make_session(ctx) + + msg = session.react("msg_001", "LIKE") + + args, kwargs = ctx.client.request.call_args + assert args[0] == "post" + assert "sess_001/messages/msg_001/reaction" in args[1] + assert kwargs["json"]["reaction"] == "LIKE" + assert msg.reaction == "LIKE" + + def test_react_clear(self): + ctx = _make_mock_context() + ctx.client.request.return_value = {**SAMPLE_MESSAGE_DICT, "reaction": None} + session = self._make_session(ctx) + + msg = session.react("msg_001", None) + + payload = ctx.client.request.call_args[1]["json"] + assert payload["reaction"] is None + assert msg.reaction is None + + +class TestSessionValidation: + """Tests that session methods enforce valid state.""" + + def _make_unsaved_session(self): + ctx = _make_mock_context() + BoundSession = _bound_session_class(ctx) + return BoundSession(agent_id="a1") + + def test_messages_requires_saved_session(self): + session = self._make_unsaved_session() + with pytest.raises(ValidationError): + session.messages() + + def test_add_message_requires_saved_session(self): + session = self._make_unsaved_session() + with pytest.raises(ValidationError): + session.add_message(role="user", content="hi") + + def test_get_message_requires_saved_session(self): + session = self._make_unsaved_session() + with pytest.raises(ValidationError): + session.get_message("msg_001") + + def test_delete_message_requires_saved_session(self): + session = self._make_unsaved_session() + with pytest.raises(ValidationError): + session.delete_message("msg_001") + + def test_react_requires_saved_session(self): + session = self._make_unsaved_session() + with pytest.raises(ValidationError): + session.react("msg_001", "LIKE") + + +# ========================================================================= +# Error handling +# ========================================================================= + + +class TestSessionErrorHandling: + """Tests that backend errors are handled gracefully.""" + + def _make_session(self, ctx=None): + ctx = ctx or _make_mock_context() + BoundSession = _bound_session_class(ctx) + session = BoundSession.from_dict(SAMPLE_SESSION_DICT) + session.context = ctx + session._update_saved_state() + return session + + # --- list() errors --- + + def test_list_api_error_propagates(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = APIError("Unauthorized", status_code=401) + BoundSession = _bound_session_class(ctx) + + with pytest.raises(APIError, match="Unauthorized"): + BoundSession.list() + + def test_list_non_list_response_raises_resource_error(self): + ctx = _make_mock_context() + ctx.client.request.return_value = {"error": "something went wrong"} + BoundSession = _bound_session_class(ctx) + + with pytest.raises(ResourceError, match="Expected a list of sessions"): + BoundSession.list() + + def test_list_malformed_item_raises_resource_error(self): + ctx = _make_mock_context() + ctx.client.request.return_value = ["not_a_dict"] + BoundSession = _bound_session_class(ctx) + + with pytest.raises(ResourceError, match="Failed to parse session"): + BoundSession.list() + + # --- messages() errors --- + + def test_messages_api_error_propagates(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = APIError("Not Found", status_code=404) + session = self._make_session(ctx) + + with pytest.raises(APIError, match="Not Found"): + session.messages() + + def test_messages_non_list_response_raises_resource_error(self): + ctx = _make_mock_context() + ctx.client.request.return_value = {"error": "bad"} + session = self._make_session(ctx) + + with pytest.raises(ResourceError, match="Expected a list of messages"): + session.messages() + + # --- add_message() errors --- + + def test_add_message_api_error_propagates(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = APIError("Bad Request", status_code=400) + session = self._make_session(ctx) + + with pytest.raises(APIError, match="Bad Request"): + session.add_message(role="user", content="hello") + + def test_add_message_malformed_response_raises_resource_error(self): + ctx = _make_mock_context() + ctx.client.request.return_value = "not_a_dict" + session = self._make_session(ctx) + + with pytest.raises(ResourceError, match="Failed to parse session message"): + session.add_message(role="user", content="hello") + + def test_add_message_file_upload_error_wraps_with_context(self): + ctx = _make_mock_context() + session = self._make_session(ctx) + + with patch("aixplain.v2.upload_utils.FileUploader") as MockUploader: + instance = MockUploader.return_value + instance.upload.side_effect = Exception("S3 timeout") + + with pytest.raises(ResourceError, match="Failed to upload file.*photo.png.*session.*sess_001"): + session.add_message(role="user", content="file", files=["/tmp/photo.png"]) + + # --- get_message() errors --- + + def test_get_message_not_found(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = APIError("Message not found", status_code=404) + session = self._make_session(ctx) + + with pytest.raises(APIError, match="Message not found"): + session.get_message("nonexistent") + + # --- delete_message() errors --- + + def test_delete_message_not_found(self): + ctx = _make_mock_context() + ctx.client.request_raw.side_effect = APIError("Message not found", status_code=404) + session = self._make_session(ctx) + + with pytest.raises(APIError, match="Message not found"): + session.delete_message("nonexistent") + + # --- react() errors --- + + def test_react_to_user_message_error(self): + ctx = _make_mock_context() + ctx.client.request.side_effect = APIError("Only assistant messages can be liked/disliked", status_code=400) + session = self._make_session(ctx) + + with pytest.raises(APIError, match="Only assistant messages"): + session.react("msg_001", "LIKE") + + def test_react_malformed_response_raises_resource_error(self): + ctx = _make_mock_context() + ctx.client.request.return_value = "bad_response" + session = self._make_session(ctx) + + with pytest.raises(ResourceError, match="Failed to parse session message"): + session.react("msg_001", "LIKE") + + +# ========================================================================= +# Agent — session convenience methods +# ========================================================================= + + +class TestAgentCreateSession: + """Tests for Agent.create_session().""" + + def test_create_session_calls_save(self): + ctx = _make_mock_context() + # Mock Session class bound to context + mock_session_instance = Mock() + mock_session_instance.save.return_value = mock_session_instance + mock_session_instance.id = "new_sess" + + MockSession = Mock(return_value=mock_session_instance) + ctx.Session = MockSession + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + session = agent.create_session(name="Chat 1") + + MockSession.assert_called_once_with(agent_id="agent_99", name="Chat 1") + mock_session_instance.save.assert_called_once() + assert session is mock_session_instance + + def test_create_session_without_name(self): + ctx = _make_mock_context() + mock_session_instance = Mock() + mock_session_instance.save.return_value = mock_session_instance + MockSession = Mock(return_value=mock_session_instance) + ctx.Session = MockSession + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + agent.create_session() + + MockSession.assert_called_once_with(agent_id="agent_99", name=None) + + def test_create_session_with_history(self): + ctx = _make_mock_context() + mock_session_instance = Mock() + mock_session_instance.save.return_value = mock_session_instance + mock_session_instance.id = "new_sess" + MockSession = Mock(return_value=mock_session_instance) + ctx.Session = MockSession + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + session = agent.create_session(name="With History", history=history) + + mock_session_instance.save.assert_called_once() + assert mock_session_instance.add_message.call_count == 3 + calls = mock_session_instance.add_message.call_args_list + assert calls[0] == call(role="user", content="Hello", attachments=None, files=None) + assert calls[1] == call(role="assistant", content="Hi there!", attachments=None, files=None) + assert calls[2] == call(role="user", content="How are you?", attachments=None, files=None) + + def test_create_session_with_history_attachments_and_files(self): + ctx = _make_mock_context() + mock_session_instance = Mock() + mock_session_instance.save.return_value = mock_session_instance + mock_session_instance.id = "new_sess" + MockSession = Mock(return_value=mock_session_instance) + ctx.Session = MockSession + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + att = [{"url": "https://example.com/img.png", "name": "img.png", "type": "image"}] + history = [ + {"role": "user", "content": "See image", "attachments": att}, + {"role": "user", "content": "And file", "files": ["/tmp/data.csv"]}, + {"role": "assistant", "content": "Got it"}, + ] + agent.create_session(name="Rich History", history=history) + + calls = mock_session_instance.add_message.call_args_list + assert calls[0] == call(role="user", content="See image", attachments=att, files=None) + assert calls[1] == call(role="user", content="And file", attachments=None, files=["/tmp/data.csv"]) + assert calls[2] == call(role="assistant", content="Got it", attachments=None, files=None) + + def test_create_session_with_invalid_history(self): + ctx = _make_mock_context() + ctx.Session = Mock() + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + with pytest.raises(ValueError): + agent.create_session(history=[{"bad": "format"}]) + + def test_create_session_requires_saved_agent(self): + ctx = _make_mock_context() + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(name="Unsaved Agent") + + with pytest.raises(ValueError, match="must be saved"): + agent.create_session() + + +class TestAgentListSessions: + """Tests for Agent.list_sessions().""" + + def test_list_sessions_delegates_to_session_list(self): + ctx = _make_mock_context() + mock_sessions = [Mock(), Mock()] + ctx.Session = Mock() + ctx.Session.list.return_value = mock_sessions + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + result = agent.list_sessions() + + ctx.Session.list.assert_called_once_with(agent_id="agent_99", status=None) + assert result == mock_sessions + + def test_list_sessions_with_status_filter(self): + ctx = _make_mock_context() + ctx.Session = Mock() + ctx.Session.list.return_value = [] + + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(id="agent_99", name="Test Agent") + + agent.list_sessions(status="completed") + + ctx.Session.list.assert_called_once_with(agent_id="agent_99", status="completed") + + def test_list_sessions_requires_saved_agent(self): + ctx = _make_mock_context() + BoundAgent = _bound_agent_class(ctx) + agent = BoundAgent(name="Unsaved Agent") + + with pytest.raises(ValueError, match="must be saved"): + agent.list_sessions() + + +# ========================================================================= +# Core registration +# ========================================================================= + + +class TestCoreSessionRegistration: + """Tests that Session is properly registered in Aixplain.""" + + def test_session_registered_on_init(self): + from aixplain.v2.core import Aixplain + + ax = Aixplain(api_key="test_key") + assert ax.Session is not None + assert issubclass(ax.Session, Session) + assert ax.Session.context is ax + + def test_session_unique_per_instance(self): + from aixplain.v2.core import Aixplain + + ax1 = Aixplain(api_key="key1") + ax2 = Aixplain(api_key="key2") + assert ax1.Session is not ax2.Session + assert ax1.Session.context is ax1 + assert ax2.Session.context is ax2 + + def test_session_enums_on_aixplain_class(self): + from aixplain.v2.core import Aixplain + + assert Aixplain.SessionStatus is SessionStatus + assert Aixplain.RunStatus is RunStatus + assert Aixplain.MessageRole is MessageRole + assert Aixplain.Reaction is Reaction + assert Aixplain.AttachmentType is AttachmentType