diff --git a/crates/offeryn-core/src/lib.rs b/crates/offeryn-core/src/lib.rs index 98649b6..ea9d145 100644 --- a/crates/offeryn-core/src/lib.rs +++ b/crates/offeryn-core/src/lib.rs @@ -1,7 +1,9 @@ +pub mod client; pub mod error; pub mod server; pub mod transport; +pub use client::StdioClient; pub use error::McpError; pub use offeryn_types::{ CallToolRequest, CallToolResult, Content, InitializeResult, ListToolsResult, diff --git a/crates/offeryn-core/src/transport/mod.rs b/crates/offeryn-core/src/transport/mod.rs index d6e8004..18d29dd 100644 --- a/crates/offeryn-core/src/transport/mod.rs +++ b/crates/offeryn-core/src/transport/mod.rs @@ -1,4 +1,4 @@ -mod sse; -mod stdio; +pub(super) mod sse; +pub(super) mod stdio; pub use sse::SseServerTransport; pub use stdio::StdioServerTransport; diff --git a/crates/offeryn-core/src/transport/stdio.rs b/crates/offeryn-core/src/transport/stdio.rs index b568a90..1e597b5 100644 --- a/crates/offeryn-core/src/transport/stdio.rs +++ b/crates/offeryn-core/src/transport/stdio.rs @@ -1,6 +1,8 @@ use crate::McpServer; use axum::async_trait; use jsonrpc_core::{Call, Error, Failure, Id, Output, Request, Response, Version}; +use offeryn_types::JsonRpcMessage; +use serde::{de::DeserializeOwned, Serialize}; use std::sync::Arc; use tokio::{ io::{ @@ -10,7 +12,7 @@ use tokio::{ }; #[async_trait] -trait StdioTransport +pub trait StdioTransport where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -150,6 +152,55 @@ where { } +pub struct StdioClientTransport +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + reader: BufReader, + writer: BufWriter, +} + +impl StdioClientTransport +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + pub fn with_streams(stdin: W, stdout: R) -> Self { + Self { + reader: BufReader::new(stdout), + writer: BufWriter::new(stdin), + } + } + + pub async fn send( + &mut self, + request: JsonRpcMessage, + ) -> Result<(), Box> { + let request_json = serde_json::to_vec(&request).unwrap(); + + Self::write_message(&mut self.writer, &request_json) + .await + .map_err(Into::into) + } + + pub async fn recv( + &mut self, + ) -> Result, Box> { + let response = Self::read_message(&mut self.reader).await?; + + serde_json::from_slice(&response).map_err(Into::into) + } +} + +#[async_trait] +impl StdioTransport for StdioClientTransport +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/offeryn-types/src/lib.rs b/crates/offeryn-types/src/lib.rs index d2ca5d7..1f96b0e 100644 --- a/crates/offeryn-types/src/lib.rs +++ b/crates/offeryn-types/src/lib.rs @@ -4,7 +4,7 @@ pub use jsonrpc_core::{ Version, }; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Map, Value}; use std::collections::HashMap; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -99,7 +99,8 @@ pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["2024-11-05"]; #[serde(rename_all = "camelCase")] pub struct JsonRpcMessage { pub jsonrpc: String, - pub id: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, #[serde(flatten)] pub content: T, } @@ -111,6 +112,24 @@ pub struct InitializeRequest { pub params: InitializeParams, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResponse { + pub result: InitializeResult, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializedNotification { + pub params: InitializedNotificationParams, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializedNotificationParams { + pub meta: Map, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListToolsResult { @@ -166,3 +185,23 @@ pub enum Content { #[serde(rename = "resource")] EmbeddedResource { uri: String, name: Option }, } + +// Collin: Pulled this over from the types generated from the MCP spec schema. +// There are missing variants for now. +#[derive(Deserialize, Serialize, Clone, Debug)] +#[serde(untagged)] +pub enum ServerResult { + InitializeResult(InitializeResult), +} + +impl From<&ServerResult> for ServerResult { + fn from(value: &ServerResult) -> Self { + value.clone() + } +} + +impl From for ServerResult { + fn from(value: InitializeResult) -> Self { + Self::InitializeResult(value) + } +}