Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/offeryn-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub mod client;
pub mod error;
pub mod server;
pub mod transport;

pub use client::StdioClient;
pub use error::McpError;
pub use offeryn_types::{
CallToolRequest, CallToolResult, Content, InitializeResult, ListToolsResult,
Expand Down
4 changes: 2 additions & 2 deletions crates/offeryn-core/src/transport/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod sse;
mod stdio;
pub(super) mod sse;
pub(super) mod stdio;
pub use sse::SseServerTransport;
pub use stdio::StdioServerTransport;
53 changes: 52 additions & 1 deletion crates/offeryn-core/src/transport/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::McpServer;
use axum::async_trait;
use jsonrpc_core::{Call, Error, Failure, Id, Output, Request, Response, Version};
use offeryn_types::JsonRpcMessage;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::{
io::{
Expand All @@ -10,7 +12,7 @@ use tokio::{
};

#[async_trait]
trait StdioTransport<R, W>
pub trait StdioTransport<R, W>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
Expand Down Expand Up @@ -150,6 +152,55 @@ where
{
}

pub struct StdioClientTransport<R, W>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
reader: BufReader<R>,
writer: BufWriter<W>,
}

impl<R, W> StdioClientTransport<R, W>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
pub fn with_streams(stdin: W, stdout: R) -> Self {
Self {
reader: BufReader::new(stdout),
writer: BufWriter::new(stdin),
}
}

pub async fn send<Req: Serialize>(
&mut self,
request: JsonRpcMessage<Req>,
) -> Result<(), Box<dyn std::error::Error>> {
let request_json = serde_json::to_vec(&request).unwrap();

Self::write_message(&mut self.writer, &request_json)
.await
.map_err(Into::into)
}

pub async fn recv<Res: DeserializeOwned>(
&mut self,
) -> Result<JsonRpcMessage<Res>, Box<dyn std::error::Error>> {
let response = Self::read_message(&mut self.reader).await?;

serde_json::from_slice(&response).map_err(Into::into)
}
}

#[async_trait]
impl<R, W> StdioTransport<R, W> for StdioClientTransport<R, W>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
43 changes: 41 additions & 2 deletions crates/offeryn-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub use jsonrpc_core::{
Version,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Map, Value};
use std::collections::HashMap;

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -99,7 +99,8 @@ pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["2024-11-05"];
#[serde(rename_all = "camelCase")]
pub struct JsonRpcMessage<T> {
pub jsonrpc: String,
pub id: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
#[serde(flatten)]
pub content: T,
}
Expand All @@ -111,6 +112,24 @@ pub struct InitializeRequest {
pub params: InitializeParams,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub result: InitializeResult,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializedNotification {
pub params: InitializedNotificationParams,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializedNotificationParams {
pub meta: Map<String, Value>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResult {
Expand Down Expand Up @@ -166,3 +185,23 @@ pub enum Content {
#[serde(rename = "resource")]
EmbeddedResource { uri: String, name: Option<String> },
}

// Collin: Pulled this over from the types generated from the MCP spec schema.
// There are missing variants for now.
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum ServerResult {
InitializeResult(InitializeResult),
}

impl From<&ServerResult> for ServerResult {
fn from(value: &ServerResult) -> Self {
value.clone()
}
}

impl From<InitializeResult> for ServerResult {
fn from(value: InitializeResult) -> Self {
Self::InitializeResult(value)
}
}