Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions openai-agents/tour-of-agents/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
)
from app.advanced.rollback_agent import agent_service as booking_with_rollback_agent
from app.advanced.manual_loop_agent import manual_loop_agent
from app.advanced.mcp import chat as mcp_chat
from app.advanced.mcp_with_approval import chat as mcp_with_approvals_chat
from app.advanced.websearch import chat as websearch_chat

from app.parallel_agents import agent_service as parallel_agent_claim_approval
from app.parallel_tools_agent import agent_service as parallel_tool_claim_agent
from app.utils.utils import fraud_agent_service, rate_comparison_agent_service, eligibility_agent_service
from app.utils.utils import (
fraud_agent_service,
rate_comparison_agent_service,
eligibility_agent_service,
)

# Create Restate app with all tour services
app = restate.app(
Expand All @@ -40,14 +47,17 @@
# Advanced patterns
booking_with_rollback_agent,
manual_loop_agent,
mcp_chat,
mcp_with_approvals_chat,
websearch_chat,
# Error handling
# Parallel processing
parallel_agent_claim_approval,
parallel_tool_claim_agent,
# Utils
fraud_agent_service,
eligibility_agent_service,
rate_comparison_agent_service
rate_comparison_agent_service,
]
)

Expand Down
14 changes: 11 additions & 3 deletions openai-agents/tour-of-agents/app/advanced/manual_loop_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import restate
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam, ChatCompletionMessageFunctionToolCall, \
ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessageParam,
ChatCompletionMessageFunctionToolCall,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from pydantic import BaseModel
from restate import Context
from openai import OpenAI, pydantic_function_tool
Expand Down Expand Up @@ -61,7 +66,10 @@ def llm_call() -> ChatCompletion:

# Check if we need to call tools
for tool_call in assistant_message.tool_calls:
if isinstance(tool_call, ChatCompletionMessageFunctionToolCall) and tool_call.function.name == "get_weather":
if (
isinstance(tool_call, ChatCompletionMessageFunctionToolCall)
and tool_call.function.name == "get_weather"
):
req = WeatherRequest.model_validate_json(tool_call.function.arguments)
tool_output = await ctx.run_typed(
"Get weather", fetch_weather, city=req.city
Expand Down
46 changes: 46 additions & 0 deletions openai-agents/tour-of-agents/app/advanced/mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List

from agents import (
Agent,
Runner,
HostedMCPTool,
TResponseInputItem,
)
from openai.types.responses.tool_param import Mcp
from restate import VirtualObject, ObjectContext, ObjectSharedContext

from app.utils.middleware import Runner, RestateSession
from app.utils.models import ChatMessage

chat = VirtualObject("McpChat")


@chat.handler()
async def message(_ctx: ObjectContext, chat_message: ChatMessage) -> str:

result = await Runner.run(
Agent(
name="Assistant",
instructions="You are a helpful assistant.",
tools=[
HostedMCPTool(
tool_config=Mcp(
type="mcp",
server_label="restate_docs",
server_description="A knowledge base about Restate's documentation.",
server_url="https://docs.restate.dev/mcp",
require_approval="never",
),
)
],
),
input=chat_message.message,
session=RestateSession(),
)
return result.final_output


@chat.handler(kind="shared")
async def get_history(_ctx: ObjectSharedContext) -> List[TResponseInputItem]:
session = RestateSession()
return await session.get_items()
61 changes: 61 additions & 0 deletions openai-agents/tour-of-agents/app/advanced/mcp_with_approval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import List

from agents import Agent, Runner, HostedMCPTool, TResponseInputItem, MCPToolApprovalRequest, \
MCPToolApprovalFunctionResult
from openai.types.responses.tool_param import Mcp
from restate import VirtualObject, ObjectContext, ObjectSharedContext

from app.utils.middleware import Runner, RestateSession
from app.utils.models import ChatMessage
from app.utils.utils import request_human_review, request_mcp_approval

chat = VirtualObject("McpWithApprovalsChat")

async def approve_func(req: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult:
restate_context = req.ctx_wrapper.context

# Request human review
approval_id, approval_promise = restate_context.awakeable(type_hint=bool)
await restate_context.run_typed(
"Approve MCP tool", request_mcp_approval, mcp_tool_name=req.data.name, awakeable_id=approval_id
)
# Wait for human approval
approved = await approval_promise
if not approved:
return {"approve": approved, "reason": "User denied"}
return {"approve": approved}



@chat.handler()
async def message(ctx: ObjectContext, chat_message: ChatMessage) -> str:

result = await Runner.run(
Agent(
name="Assistant",
instructions="You are a helpful assistant.",
tools = [
HostedMCPTool(
tool_config=Mcp(
type="mcp",
server_label="restate_docs",
server_description="A knowledge base about Restate's documentation.",
server_url="https://docs.restate.dev/mcp"
),
on_approval_request=approve_func
# or use require_approval="never" in the tool_config to disable approvals
)
],
),
input=chat_message.message,
session=RestateSession(),
context=ctx
)

return result.final_output


@chat.handler(kind="shared")
async def get_history(_ctx: ObjectSharedContext) -> List[TResponseInputItem]:
session = RestateSession()
return await session.get_items()
27 changes: 13 additions & 14 deletions openai-agents/tour-of-agents/app/advanced/rollback_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Callable

import restate
from agents import Agent, RunConfig, Runner, function_tool, RunContextWrapper
from agents import Agent, RunConfig, Runner, RunContextWrapper
from pydantic import Field, BaseModel, ConfigDict
from restate import TerminalError

from app.utils.middleware import DurableModelCalls, raise_restate_errors
from app.utils.middleware import Runner, function_tool
from app.utils.models import HotelBooking, FlightBooking, BookingPrompt, BookingResult
from app.utils.utils import (
reserve_hotel,
Expand All @@ -25,7 +25,7 @@ class BookingContext(BaseModel):


# Functions raise terminal errors instead of feeding them back to the agent
@function_tool(failure_error_function=raise_restate_errors)
@function_tool
async def book_hotel(
wrapper: RunContextWrapper[BookingContext], booking: HotelBooking
) -> BookingResult:
Expand All @@ -41,11 +41,14 @@ async def book_hotel(

# Execute the workflow step
return await booking_context.restate_context.run_typed(
"Book hotel", reserve_hotel, booking_id=booking_context.booking_id, booking=booking
"Book hotel",
reserve_hotel,
booking_id=booking_context.booking_id,
booking=booking,
)


@function_tool(failure_error_function=raise_restate_errors)
@function_tool
async def book_flight(
wrapper: RunContextWrapper[BookingContext], booking: FlightBooking
) -> BookingResult:
Expand All @@ -58,7 +61,10 @@ async def book_flight(
)
)
return await booking_context.restate_context.run_typed(
"Book flight", reserve_flight, booking_id=booking_context.booking_id, booking=booking
"Book flight",
reserve_flight,
booking_id=booking_context.booking_id,
booking=booking,
)


Expand All @@ -83,14 +89,7 @@ async def book(restate_context: restate.Context, prompt: BookingPrompt) -> str:
)

try:
result = await Runner.run(
booking_agent,
input=prompt.message,
context=booking_context,
run_config=RunConfig(
model="gpt-4o", model_provider=DurableModelCalls(restate_context)
),
)
result = await Runner.run(booking_agent, input=prompt.message)
except TerminalError as e:
# Run all the rollback actions on terminal errors
for compensation in reversed(booking_context.on_rollback):
Expand Down
37 changes: 37 additions & 0 deletions openai-agents/tour-of-agents/app/advanced/websearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List

from agents import (
Agent,
Runner,
WebSearchTool,
TResponseInputItem,
)
from restate import VirtualObject, ObjectContext, ObjectSharedContext

from app.utils.middleware import Runner, RestateSession
from app.utils.models import ChatMessage

chat = VirtualObject("WebsearchChat")


@chat.handler()
async def message(restate_context: ObjectContext, chat_message: ChatMessage) -> str:

result = await Runner.run(
Agent(
name="Assistant",
instructions="You are a helpful assistant.",
tools=[
WebSearchTool()
],
),
input=chat_message.message,
session=RestateSession(),
)
return result.final_output


@chat.handler(kind="shared")
async def get_history(_ctx: ObjectSharedContext) -> List[TResponseInputItem]:
session = RestateSession()
return await session.get_items()
25 changes: 9 additions & 16 deletions openai-agents/tour-of-agents/app/chat.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
from agents import Agent, RunConfig, Runner, ModelSettings
from typing import List

from agents import Agent, Runner, TResponseInputItem
from restate import VirtualObject, ObjectContext, ObjectSharedContext

from app.utils.middleware import DurableModelCalls, RestateSession
from app.utils.middleware import Runner, RestateSession
from app.utils.models import ChatMessage

chat = VirtualObject("Chat")


@chat.handler()
async def message(restate_context: ObjectContext, chat_message: ChatMessage) -> dict:

restate_session = await RestateSession.create(
session_id=restate_context.key(), ctx=restate_context
)

async def message(_ctx: ObjectContext, chat_message: ChatMessage) -> dict:
result = await Runner.run(
Agent(name="Assistant", instructions="You are a helpful assistant."),
input=chat_message.message,
run_config=RunConfig(
model="gpt-4o",
model_provider=DurableModelCalls(restate_context),
model_settings=ModelSettings(parallel_tool_calls=False),
),
session=restate_session,
session=RestateSession(),
)
return result.final_output


@chat.handler(kind="shared")
async def get_history(ctx: ObjectSharedContext):
return await ctx.get("items") or []
async def get_history(_ctx: ObjectSharedContext) -> List[TResponseInputItem]:
session = RestateSession()
return await session.get_items()
33 changes: 6 additions & 27 deletions openai-agents/tour-of-agents/app/durable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,30 @@
from agents import (
Agent,
RunConfig,
Runner,
function_tool,
RunContextWrapper,
ModelSettings,
)

from app.utils.middleware import DurableModelCalls, raise_restate_errors
from app.utils.middleware import Runner, function_tool
from app.utils.models import WeatherPrompt, WeatherRequest, WeatherResponse
from app.utils.utils import fetch_weather


@function_tool(failure_error_function=raise_restate_errors)
async def get_weather(
wrapper: RunContextWrapper[restate.Context], req: WeatherRequest
) -> WeatherResponse:
@function_tool
async def get_weather(city: WeatherRequest) -> WeatherResponse:
"""Get the current weather for a given city."""
# Do durable steps using the Restate context
restate_context = wrapper.context
return await restate_context.run_typed("Get weather", fetch_weather, city=req.city)
return await fetch_weather(city)


weather_agent = Agent[restate.Context](
weather_agent = Agent(
name="WeatherAgent",
instructions="You are a helpful agent that provides weather updates.",
tools=[get_weather],
)


agent_service = restate.Service("WeatherAgent")


@agent_service.handler()
async def run(restate_context: restate.Context, prompt: WeatherPrompt) -> str:

result = await Runner.run(
weather_agent,
input=prompt.message,
# Pass the Restate context to tools to make tool execution steps durable
context=restate_context,
# Choose any model and let Restate persist your calls
run_config=RunConfig(
model="gpt-4o",
model_provider=DurableModelCalls(restate_context),
model_settings=ModelSettings(parallel_tool_calls=False),
),
)

result = await Runner.run(weather_agent, input=prompt.message)
return result.final_output
Loading
Loading