11from enum import Enum
2- from typing import Any , Optional
2+ from typing import Any , Optional , Union
33
44from openai .types .chat import ChatCompletionMessageParam
5- from pydantic import BaseModel
5+ from pydantic import BaseModel , Field
6+ from pydantic_ai .messages import ModelRequest , ModelResponse
67
78
89class AIChatRoles (str , Enum ):
@@ -41,14 +42,34 @@ class ChatRequest(BaseModel):
4142 sessionState : Optional [Any ] = None
4243
4344
45+ class ItemPublic (BaseModel ):
46+ id : int
47+ type : str
48+ brand : str
49+ name : str
50+ description : str
51+ price : float
52+
53+ def to_str_for_rag (self ):
54+ return f"Name:{ self .name } Description:{ self .description } Price:{ self .price } Brand:{ self .brand } Type:{ self .type } "
55+
56+
57+ class ItemWithDistance (ItemPublic ):
58+ distance : float
59+
60+ def __init__ (self , ** data ):
61+ super ().__init__ (** data )
62+ self .distance = round (self .distance , 2 )
63+
64+
4465class ThoughtStep (BaseModel ):
4566 title : str
4667 description : Any
4768 props : dict = {}
4869
4970
5071class RAGContext (BaseModel ):
51- data_points : dict [int , dict [ str , Any ] ]
72+ data_points : dict [int , ItemPublic ]
5273 thoughts : list [ThoughtStep ]
5374 followup_questions : Optional [list [str ]] = None
5475
@@ -69,27 +90,39 @@ class RetrievalResponseDelta(BaseModel):
6990 sessionState : Optional [Any ] = None
7091
7192
72- class ItemPublic (BaseModel ):
73- id : int
74- type : str
75- brand : str
76- name : str
77- description : str
78- price : float
79-
80-
81- class ItemWithDistance (ItemPublic ):
82- distance : float
83-
84- def __init__ (self , ** data ):
85- super ().__init__ (** data )
86- self .distance = round (self .distance , 2 )
87-
88-
8993class ChatParams (ChatRequestOverrides ):
9094 prompt_template : str
9195 response_token_limit : int = 1024
9296 enable_text_search : bool
9397 enable_vector_search : bool
9498 original_user_query : str
95- past_messages : list [ChatCompletionMessageParam ]
99+ past_messages : list [Union [ModelRequest , ModelResponse ]]
100+
101+
102+ class Filter (BaseModel ):
103+ column : str
104+ comparison_operator : str
105+ value : Any
106+
107+
108+ class PriceFilter (Filter ):
109+ column : str = Field (default = "price" , description = "The column to filter on (always 'price' for this filter)" )
110+ comparison_operator : str = Field (description = "The operator for price comparison ('>', '<', '>=', '<=', '=')" )
111+ value : float = Field (description = "The price value to compare against (e.g., 30.00)" )
112+
113+
114+ class BrandFilter (Filter ):
115+ column : str = Field (default = "brand" , description = "The column to filter on (always 'brand' for this filter)" )
116+ comparison_operator : str = Field (description = "The operator for brand comparison ('=' or '!=')" )
117+ value : str = Field (description = "The brand name to compare against (e.g., 'AirStrider')" )
118+
119+
120+ class SearchResults (BaseModel ):
121+ query : str
122+ """The original search query"""
123+
124+ items : list [ItemPublic ]
125+ """List of items that match the search query and filters"""
126+
127+ filters : list [Filter ]
128+ """List of filters applied to the search results"""
0 commit comments