-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathagent.py
More file actions
75 lines (49 loc) · 1.89 KB
/
agent.py
File metadata and controls
75 lines (49 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import datetime
import os
import uuid
from typing import List, Optional
import dotenv
import uvicorn
from fastapi import FastAPI, status
from fastapi.responses import FileResponse, Response
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.tools import tool
from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI
from pydantic import BaseModel, Field
# FastAPI application instance; the route handlers below register themselves on it.
app = FastAPI()
@app.get('/manifest.json')
async def get_manifest() -> Response:
    """Serve the plugin manifest from the current working directory."""
    manifest = FileResponse('manifest.json')
    return manifest
@app.get('/logo.png')
async def get_logo() -> Response:
    """Serve the logo image from the current working directory."""
    logo = FileResponse('logo.png')
    return logo
class SessionBase(BaseModel):
    """Client-supplied session attributes (request body for session creation)."""
    # Locale identifiers requested by the client; format is not validated here.
    locales: List[str]
class Session(SessionBase):
    """A created session: the client attributes plus a server-generated id."""
    # Random UUID generated per instance via default_factory.
    id: uuid.UUID = Field(default_factory=uuid.uuid4)
@app.post('/sessions', status_code=status.HTTP_201_CREATED)
async def create_session(req: SessionBase) -> Session:
    """Create a session from the request body; the id is auto-generated.

    Note: the session is not persisted anywhere in this module.
    """
    return Session(**req.model_dump())
class QuestionRequest(BaseModel):
    """Body of a question POST; the question text is optional and defaults to ""."""
    question: Optional[str] = ""
class AgentStep(BaseModel):
    """One step of an agent reply; this module only ever emits "message" steps."""
    # Step kind; defaults to a plain message.
    action: str = "message"
    # Step payload (the message text).
    value: str
class QuestionResponse(BaseModel):
    """Response wrapper: the ordered list of steps the agent produced."""
    steps: List[AgentStep]
@tool
def clock():
    """gets the current time"""
    # Naive local time, stringified in the default "YYYY-MM-DD HH:MM:SS.ffffff" form.
    # NOTE(review): not timezone-aware — confirm whether UTC is expected by callers.
    now = datetime.datetime.now()
    return str(now)
@app.post('/sessions/{session_id}/questions', status_code=status.HTTP_200_OK)
async def answer_question(session_id: str, req: QuestionRequest) -> QuestionResponse:
    """Run the agent on the question and wrap its answer in one message step.

    NOTE(review): session_id is accepted but unused — a fresh agent is built per
    request, so there is no per-session conversation memory here.
    """
    result = build_agent().invoke(req.question)
    step = AgentStep(value=result['output'])
    return QuestionResponse(steps=[step])
def build_agent() -> AgentExecutor:
    """Construct a conversational agent backed by the configured chat model.

    The model name is read from the MODEL_NAME environment variable and the
    agent is equipped with the `clock` tool.

    Raises:
        RuntimeError: if MODEL_NAME is unset or empty — failing fast with a
            clear message instead of passing None to ChatOpenAI and producing
            an obscure error deep inside the client.
    """
    model_name = os.getenv("MODEL_NAME")
    if not model_name:
        raise RuntimeError("MODEL_NAME environment variable is not set")
    llm = ChatOpenAI(model_name=model_name, temperature=0.7, verbose=True, streaming=True)
    # Cap the reasoning loop so a confused agent cannot spin indefinitely.
    return create_conversational_retrieval_agent(llm, [clock], max_iterations=3)
if __name__ == "__main__":
    # Load .env into the environment before serving, so MODEL_NAME (read lazily
    # in build_agent) is available; reload workers inherit the environment.
    dotenv.load_dotenv()
    uvicorn.run("agent:app", host="0.0.0.0", port=8000, log_level="info", reload=True)