@@ -130,11 +130,43 @@ class AgentControllerData(BaseModel):
130130 ] = None
131131
132132
133+ def save_messages_to_session_data (session_id , id , messages : List [AgentMessage ]):
134+ from llmstack .apps .app_session_utils import save_app_session_data
135+
136+ logger .info (f"Saving messages to session data: { messages } " )
137+
138+ save_app_session_data (session_id , id , [m .model_dump_json () for m in messages ])
139+
140+
141+ def load_messages_from_session_data (session_id , id ):
142+ from llmstack .apps .app_session_utils import get_app_session_data
143+
144+ messages = []
145+
146+ session_data = get_app_session_data (session_id , id )
147+ if session_data and isinstance (session_data , list ):
148+ for data in session_data :
149+ data_json = json .loads (data )
150+ if data_json ["role" ] == "system" :
151+ messages .append (AgentSystemMessage (** data_json ))
152+ elif data_json ["role" ] == "assistant" :
153+ messages .append (AgentAssistantMessage (** data_json ))
154+ elif data_json ["role" ] == "user" :
155+ messages .append (AgentUserMessage (** data_json ))
156+
157+ return messages
158+
159+
133160class AgentController :
134161 def __init__ (self , output_queue : asyncio .Queue , config : AgentControllerConfig ):
162+ self ._session_id = config .metadata .get ("session_id" )
163+ self ._controller_id = f"{ config .metadata .get ('app_uuid' )} _agent"
164+ self ._system_message = render_template (config .agent_config .system_message , {})
135165 self ._output_queue = output_queue
136166 self ._config = config
137- self ._messages : List [AgentMessage ] = []
167+ self ._messages : List [AgentMessage ] = (
168+ load_messages_from_session_data (self ._session_id , self ._controller_id ) or []
169+ )
138170 self ._llm_client = None
139171 self ._websocket = None
140172 self ._provider_config = None
@@ -254,18 +286,6 @@ def _init_llm_client(self):
254286 ),
255287 )
256288
257- self ._messages .append (
258- AgentSystemMessage (
259- role = AgentMessageRole .SYSTEM ,
260- content = [
261- AgentMessageContent (
262- type = AgentMessageContentType .TEXT ,
263- data = render_template (self ._config .agent_config .system_message , {}),
264- )
265- ],
266- )
267- )
268-
269289 async def _process_input_audio_stream (self ):
270290 if self ._input_audio_stream :
271291 async for chunk in self ._input_audio_stream .read_async ():
@@ -387,6 +407,10 @@ def process(self, data: AgentControllerData):
387407 # Actor calls this to add a message to the conversation and trigger processing
388408 self ._messages .append (data .data )
389409
410+ # This is a message from the assistant to the user, simply add it to the message to maintain state
411+ if data .type == AgentControllerDataType .AGENT_OUTPUT_END or data .type == AgentControllerDataType .TOOL_CALLS_END :
412+ return
413+
390414 try :
391415 if len (self ._messages ) > self ._config .agent_config .max_steps :
392416 raise Exception (f"Max steps ({ self ._config .agent_config .max_steps } ) exceeded: { len (self ._messages )} " )
@@ -465,7 +489,7 @@ async def process_messages(self, data: AgentControllerData):
465489 stream = True if self ._config .agent_config .stream is None else self ._config .agent_config .stream
466490 response = self ._llm_client .chat .completions .create (
467491 model = self ._config .agent_config .model ,
468- messages = client_messages ,
492+ messages = [{ "role" : "system" , "content" : self . _system_message }] + client_messages ,
469493 stream = stream ,
470494 tools = self ._config .tools ,
471495 )
@@ -703,6 +727,9 @@ async def add_ws_event_to_output_queue(self, event: Any):
703727 logger .error (f"WebSocket error: { event } " )
704728
705729 def terminate (self ):
730+ # Save to session data
731+ save_messages_to_session_data (self ._session_id , self ._controller_id , self ._messages )
732+
706733 # Create task for graceful websocket closure
707734 if hasattr (self , "_websocket" ) and self ._websocket :
708735 asyncio .run_coroutine_threadsafe (self ._websocket .close (), self ._loop )
0 commit comments