diff --git a/examples/interior_design_assistant/api.py b/examples/interior_design_assistant/api.py index 80591ad3a..bae8c2fe7 100644 --- a/examples/interior_design_assistant/api.py +++ b/examples/interior_design_assistant/api.py @@ -21,11 +21,17 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.types import MemoryToolDefinition, SamplingParams from llama_stack_client.types.agent_create_params import AgentConfig +from pydantic import BaseModel from termcolor import cprint MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct" +class Output(BaseModel): + description: str + items: list[str] + + class InterioAgent: def __init__(self, document_dir: str, image_dir: str): self.document_dir = document_dir @@ -44,6 +50,10 @@ async def _get_agent(self): instructions="", sampling_params=SamplingParams(strategy="greedy", temperature=0.0), enable_session_persistence=True, + response_format={ + "type": "json_schema", + "json_schema": Output.model_json_schema() + } ) response = self.client.agents.create( agent_config=agent_config, @@ -106,13 +116,9 @@ async def list_items(self, file_path: str) -> List[str]: break result = turn.output_message.content - try: - d = json.loads(result.strip()) - except Exception: - cprint(f"Error parsing JSON output: {result}", color="red") - raise + result = json.loads(result.strip()) - return d + return result async def suggest_alternatives( self, file_path: str, item: str, n: int = 3 @@ -161,13 +167,13 @@ async def suggest_alternatives( ], } - resposne = self.client.agents.session.create( + response = self.client.agents.session.create( agent_id=self.agent_id, session_name=uuid.uuid4().hex, ) generator = self.client.agents.turn.create( agent_id=self.agent_id, - session_id=resposne.session_id, + session_id=response.session_id, messages=[message], stream=True, ) @@ -178,8 +184,8 @@ async def suggest_alternatives( turn = payload.turn result = turn.output_message.content - print(result) - return [r["description"].strip() for r in json.loads(result.strip())] + result = json.loads(result.strip()) + return [result["description"].strip()] async def retrieve_images(self, description: str): """