diff --git a/.release-config.json b/.release-config.json index 304ccc9..d87cd6a 100644 --- a/.release-config.json +++ b/.release-config.json @@ -65,6 +65,7 @@ "pull-request-header": ":robot: Auto-generated release PR", "packages": { "crates/rust-mcp-macros": { + "release-as": "0.8.0", "release-type": "rust", "draft": false, "prerelease": false, @@ -79,6 +80,7 @@ ] }, "crates/rust-mcp-transport": { + "release-as": "0.8.0", "release-type": "rust", "draft": false, "prerelease": false, @@ -92,7 +94,8 @@ } ] }, - "crates/rust-mcp-extra": { + "crates/rust-mcp-sdk": { + "release-as": "0.8.0", "release-type": "rust", "draft": false, "prerelease": false, @@ -106,7 +109,7 @@ } ] }, - "crates/rust-mcp-sdk": { + "crates/rust-mcp-extra": { "release-type": "rust", "draft": false, "prerelease": false, diff --git a/Cargo.lock b/Cargo.lock index 314c813..8ce5772 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,7 +120,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper 1.8.1", @@ -151,7 +151,7 @@ checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "mime", @@ -171,7 +171,7 @@ dependencies = [ "arc-swap", "bytes", "fs-err", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "hyper 1.8.1", "hyper-util", @@ -204,9 +204,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "bitflags" @@ -237,9 +237,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.47" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ "find-msvc-tools", "jobserver", @@ -275,9 +275,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586" dependencies = [ "cc", ] @@ -759,7 +759,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.3.1", + "http 1.4.0", "indexmap", "slab", "tokio", @@ -846,12 +846,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -873,7 +872,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.3.1", + "http 1.4.0", ] [[package]] @@ -884,7 +883,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "pin-project-lite", ] @@ -957,7 +956,7 @@ dependencies = [ "futures-channel", "futures-core", "h2 0.4.12", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "httparse", "httpdate", @@ -975,7 +974,7 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http 1.3.1", + "http 1.4.0", "hyper 1.8.1", "hyper-util", "rustls", @@ -1004,16 +1003,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", "futures-core", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "hyper 1.8.1", "ipnet", @@ -1100,9 +1099,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ "icu_collections", "icu_locale_core", @@ -1114,9 +1113,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" @@ -1213,9 +1212,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -1249,9 +1248,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libm" @@ -1288,9 +1287,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -1337,9 +1336,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi 0.11.1+wasi-snapshot-preview1", @@ -1455,16 +1454,16 @@ dependencies = [ [[package]] name = "oauth2-test-server" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bb78cf155f91eba1d99533e49aafc31f5e7e42b9964d2c0c8470d6641accb54" +checksum = "e66b9483c4680a03f8f3a414e02d9e2b2d12702946d2fd05d58c3da4406630d2" dependencies = [ "axum", "base64 0.21.7", "chrono", "colored", "futures", - "http 1.3.1", + "http 1.4.0", "jsonwebtoken", "once_cell", "rand 0.8.5", @@ -1893,9 +1892,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a" dependencies = [ "base64 0.22.1", "bytes", @@ -1905,7 +1904,7 @@ dependencies = [ "futures-core", "futures-util", "h2 0.4.12", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper 1.8.1", @@ -1931,7 +1930,7 @@ dependencies = [ "tokio-rustls", "tokio-util", "tower", - "tower-http 0.6.6", + "tower-http 0.6.8", "tower-service", "url", "wasm-bindgen", @@ -1989,7 +1988,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "nanoid", @@ -2021,9 +2020,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.7.5" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba217e6fcb043bba9e194209bff92c35294093187504d1443832ca2051816753" +checksum = "8b6cf84194ba1c1703c7ad0a6730b483f1a34dd32057e8e7226387da3f876591" dependencies = [ "serde", "serde_json", @@ -2039,7 +2038,7 @@ dependencies = [ "base64 0.22.1", "bytes", "futures", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper 1.8.1", @@ -2124,9 +2123,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" dependencies = [ "web-time", "zeroize", @@ -2375,6 +2374,8 @@ dependencies = [ "serde_json", "thiserror 2.0.17", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] @@ -2761,7 +2762,7 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags", "bytes", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -2772,14 +2773,14 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags", "bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "iri-string", "pin-project-lite", @@ -2802,9 +2803,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -2814,9 +2815,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -2825,9 +2826,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -2846,9 +2847,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -2918,9 +2919,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.18.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.4", "js-sys", @@ -2983,9 +2984,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -2996,9 +2997,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -3009,9 +3010,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3019,9 +3020,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", @@ -3032,9 +3033,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] @@ -3054,9 +3055,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", @@ -3375,18 +3376,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.28" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43fa6694ed34d6e57407afbccdeecfa268c470a7d2a5b0cf49ce9fcc345afb90" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.28" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c640b22cd9817fae95be82f0d2f90b11f7605f6c319d16705c459b27ac2cbc26" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3b7f98c..cd0dcc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,8 +30,7 @@ rust-mcp-macros = { version = "0.5.3", path = "crates/rust-mcp-macros", default- rust-mcp-extra = { version="0.1.0", path = "crates/rust-mcp-extra", default-features = false } # External crates -rust-mcp-schema = { version = "0.7", default-features = false } - +rust-mcp-schema = { version="0.9", default-features = false } futures = { version = "0.3" } tokio = { version = "1.4", features = ["full"] } diff --git a/README.md b/README.md index d92d964..715f280 100644 --- a/README.md +++ b/README.md @@ -11,26 +11,21 @@ [Hello World MCP Server ](examples/hello-world-mcp-server-stdio) -A high-performance, asynchronous toolkit for building MCP servers and clients. -Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! -**rust-mcp-sdk** provides the necessary components for developing both servers and clients in the MCP ecosystem. -Leveraging the [rust-mcp-schema](https://github.com/rust-mcp-stack/rust-mcp-schema) crate simplifies the process of building robust and reliable MCP servers and clients, ensuring consistency and minimizing errors in data handling and message processing. +A high-performance, asynchronous Rust toolkit for building MCP servers and clients. +Focus on your application logic - rust-mcp-sdk handles the protocol, transports, and the rest! +This SDK fully implements the latest MCP protocol version ([2025-11-25](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema)), with backward compatibility built-in. `rust-mcp-sdk` provides the necessary components for developing both servers and clients in the MCP ecosystem. It leverages the [rust-mcp-schema](https://crates.io/crates/rust-mcp-schema) crate for type-safe schema objects and includes powerful procedural macros for tools and user input elicitation. -**rust-mcp-sdk** supports all three official versions of the MCP protocol. -By default, it uses the **2025-06-18** version, but earlier versions can be enabled via Cargo features. - -πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - - -**Features** -- βœ… Stdio, SSE and Streamable HTTP Support -- βœ… Supports multiple MCP protocol versions +**Key Features** +- βœ… Latest MCP protocol specification supported: 2025-11-25 +- βœ… Transports:Stdio, Streamable HTTP, and backward-compatible SSE support +- βœ… Lightweight Axum-based server for Streamable HTTP and SSE +- βœ… Multi-client concurrency - βœ… DNS Rebinding Protection +- βœ… Resumability - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- βœ… Resumability - βœ… OAuth Authentication for MCP Servers - βœ… [Remote Oauth Provider](crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs) (for any provider with DCR support) - βœ… **Keycloak** Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#keycloak)) @@ -41,24 +36,26 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents -- [Getting Started](#getting-started) +- [Quick Start](#quick-start) + - [Minimal MCP Server (Stdio)]([#minimal-mcp-server-stdio](#minimal-mcp-server-stdio)) + - [Minimal MCP Server (Streamable HTTP)](#minimal-mcp-server-streamable-http) + - [Minimal MCP Client (Stdio)](#minimal-mcp-client-stdio) - [Usage Examples](#usage-examples) - - [MCP Server (stdio)](#mcp-server-stdio) - - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - - [MCP Client (stdio)](#mcp-client-stdio) - - [MCP Client (Streamable HTTP)](#mcp-client-streamable-http) - - [MCP Client (sse)](#mcp-client-sse) -- [Authentication](#authentication) - [Macros](#macros) + - [mcp_tool](#mcp_tool) + - [tool_box](#-tool_box) + - [mcp_icon](#-mcp_icon) +- [Authentication](#authentication) + - [RemoteAuthProvider](#remoteauthprovider) + - [OAuthProxy](#oauthproxy) - [HyperServerOptions](#hyperserveroptions) - - [Security Considerations](#security-considerations) +- [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) - [Available Features](#available-features) - - [MCP protocol versions with corresponding features](#mcp-protocol-versions-with-corresponding-features) - [Default Features](#default-features) - [Using Only the server Features](#using-only-the-server-features) - [Using Only the client Features](#using-only-the-client-features) -- [Choosing Between Standard and Core Handlers traits](#choosing-between-standard-and-core-handlers-traits) +- [Handler Traits](#handlers-traits) - [Choosing Between **ServerHandler** and **ServerHandlerCore**](#choosing-between-serverhandler-and-serverhandlercore) - [Choosing Between **ClientHandler** and **ClientHandlerCore**](#choosing-between-clienthandler-and-clienthandlercore) - [Projects using Rust MCP SDK](#projects-using-rust-mcp-sdk) @@ -67,330 +64,339 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab - [License](#license) -## Getting Started -If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) -## Usage Examples +## Quick Start -### MCP Server (stdio) + -Create a MCP server with a `tool` that will print a `Hello World!` message: +Add to your Cargo.toml: +```toml +[dependencies] +rust-mcp-sdk = "0.9.0" # Check crates.io for the latest version +``` + + + +## Minimal MCP Server (Stdio) +```rs +use async_trait::async_trait; +use rust_mcp_sdk::{*,error::SdkResult,macros,mcp_server::{server_runtime, ServerHandler},schema::*,}; + +// Define a mcp tool +#[macros::mcp_tool(name = "say_hello", description = "returns \"Hello from Rust MCP SDK!\" message ")] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] +pub struct SayHelloTool {} + +// define a custom handler +#[derive(Default)] +struct HelloHandler; + +// implement ServerHandler +#[async_trait] +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + tools: vec![SayHelloTool::tool()], + meta: None, + next_cursor: None, + }) + } + // Handles requests to call a specific tool. + async fn handle_call_tool_request(&self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" { + Ok(CallToolResult::text_content(vec!["Hello from Rust MCP SDK!".into()])) + } else { + Err(CallToolError::unknown_tool(params.name)) + } + } +} -```rust #[tokio::main] async fn main() -> SdkResult<()> { - - // STEP 1: Define server details and capabilities - let server_details = InitializeResult { - // server name and version + // Define server details and capabilities + let server_info = InitializeResult { server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), - }, - capabilities: ServerCapabilities { - // indicates that server support mcp tools - tools: Some(ServerCapabilitiesTools { list_changed: None }), - ..Default::default() // Using default values for other fields + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, - meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None }; - // STEP 2: create a std transport with default options let transport = StdioTransport::new(TransportOptions::default())?; - - // STEP 3: instantiate our custom handler for handling MCP messages - let handler = MyServerHandler {}; - - // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); - - // STEP 5: Start the server + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = server_runtime::create_server(server_info, transport, handler); server.start().await - } ``` -See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : - -![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) - -### MCP Server (Streamable HTTP) - -Creating an MCP server in `rust-mcp-sdk` with the `sse` transport allows multiple clients to connect simultaneously with no additional setup. -Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and HyperServerOptions. - - -πŸ’‘ By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. To disable the SSE transport , set the `sse_support` to false in the `HyperServerOptions`. - +## Minimal MCP Server (Streamable HTTP) +Creating an MCP server in `rust-mcp-sdk` allows multiple clients to connect simultaneously with no additional setup. +The setup is nearly identical to the stdio example shown above. You only need to create a Hyper server via `hyper_server::create_server()` and pass in the same handler and `HyperServerOptions`. +πŸ’‘ If backward compatibility is required, you can enable **SSE** transport by setting `sse_support` to true in `HyperServerOptions`. ```rust - -// STEP 1: Define server details and capabilities -let server_details = InitializeResult { - // server name and version - server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), - }, - capabilities: ServerCapabilities { - // indicates that server support mcp tools - tools: Some(ServerCapabilitiesTools { list_changed: None }), - ..Default::default() // Using default values for other fields - }, - meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), +use async_trait::async_trait; +use rust_mcp_sdk::{*,error::SdkResult,event_store::InMemoryEventStore,macros, + mcp_server::{hyper_server, HyperServerOptions, ServerHandler},schema::*, }; -// STEP 2: instantiate our custom handler for handling MCP messages -let handler = MyServerHandler {}; - -// STEP 3: instantiate HyperServer, providing `server_details` , `handler` and HyperServerOptions -let server = hyper_server::create_server( - server_details, - handler, - HyperServerOptions { - host: "127.0.0.1".to_string(), - sse_support: false, - event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability - ..Default::default() - }, -); - -// STEP 4: Start the server -server.start().await?; - -Ok(()) -``` - - -The implementation of `MyServerHandler` is the same regardless of the transport used and could be as simple as the following: - -```rust - -// STEP 1: Define a rust_mcp_schema::Tool ( we need one with no parameters for this example) -#[mcp_tool(name = "say_hello_world", description = "Prints \"Hello World!\" message")] -#[derive(Debug, Deserialize, Serialize, JsonSchema)] +// Define a mcp tool +#[macros::mcp_tool( + name = "say_hello", + description = "returns \"Hello from Rust MCP SDK!\" message " +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] pub struct SayHelloTool {} -// STEP 2: Implement ServerHandler trait for a custom handler -// For this example , we only need handle_list_tools_request() and handle_call_tool_request() methods. -pub struct MyServerHandler; +// define a custom handler +#[derive(Default)] +struct HelloHandler; +// implement ServerHandler #[async_trait] -impl ServerHandler for MyServerHandler { - // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { - - Ok(ListToolsResult { - tools: vec![SayHelloTool::tool()], - meta: None, - next_cursor: None, - }) - +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult {tools: vec![SayHelloTool::tool()],meta: None,next_cursor: None}) } - - /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { - - if request.tool_name() == SayHelloTool::tool_name() { - Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) + // Handles requests to call a specific tool. + async fn handle_call_tool_request( + &self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" {Ok(CallToolResult::text_content(vec!["Hello from Rust MCP SDK!".into()])) } else { - Err(CallToolError::unknown_tool(request.tool_name().to_string())) + Err(CallToolError::unknown_tool(params.name)) } - } } -``` - ---- - -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** -See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +#[tokio::main] +async fn main() -> SdkResult<()> { + // Define server details and capabilities + let server_info = InitializeResult { + server_info: Implementation { + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), + }, + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None + }; -![mcp-server in rust](assets/examples/hello-world-server-streamable-http.gif) + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = hyper_server::create_server( + server_info, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability + ..Default::default() + }, + ); + server.start().await?; + Ok(()) +} +``` ---- -### MCP Client (stdio) +## Minimal MCP Client (Stdio) +Following is implementation of an MCP client that starts the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, displays the server's name, version, and list of tools provided by the server. -Create an MCP client that starts the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, displays the server's name, version, and list of tools, then uses the add tool provided by the server to sum 120 and 28, printing the result. ```rust +use async_trait::async_trait; +use rust_mcp_sdk::{*, error::SdkResult, + mcp_client::{client_runtime, ClientHandler}, + schema::*, +}; -// STEP 1: Custom Handler to handle incoming MCP Messages +// Custom Handler to handle incoming MCP Messages pub struct MyClientHandler; - #[async_trait] impl ClientHandler for MyClientHandler { - // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs + // To see all the trait methods you can override, + // check out: + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs } #[tokio::main] async fn main() -> SdkResult<()> { - - // Step2 : Define client details and capabilities + // Client details and capabilities let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { name: "simple-rust-mcp-client".into(), version: "0.1.0".into(), + description: None, + icons: vec![], + title: None, + website_url: None, }, - protocol_version: LATEST_PROTOCOL_VERSION.into(), + protocol_version: ProtocolVersion::V2025_11_25.into(), + meta: None, }; - // Step3 : Create a transport, with options to launch @modelcontextprotocol/server-everything MCP Server + // Create a transport, with options to launch @modelcontextprotocol/server-everything MCP Server let transport = StdioTransport::create_with_server_launch( - "npx", - vec![ "-y".to_string(), "@modelcontextprotocol/server-everything".to_string()], - None, TransportOptions::default() + "npx",vec!["-y".to_string(),"@modelcontextprotocol/server-everything@latest".to_string()], + None, + TransportOptions::default(), )?; - // STEP 4: instantiate our custom handler for handling MCP messages + // instantiate our custom handler for handling MCP messages let handler = MyClientHandler {}; - // STEP 5: create a MCP client - let client = client_runtime::create_client(client_details, transport, handler); - - // STEP 6: start the MCP client + // Create and start the MCP client + let client = client_runtime::create_client(client_details, transport, handler); client.clone().start().await?; + // use client methods to communicate with the MCP Server as you wish: - // STEP 7: use client methods to communicate with the MCP Server as you wish - + let server_version = client.server_version().unwrap(); + // Retrieve and display the list of tools available on the server - let server_version = client.server_version().unwrap(); - let tools = client.list_tools(None).await?.tools; - - println!("List of tools for {}@{}", server_version.name, server_version.version); - + let tools = client.request_tool_list(None).await?.tools; + println!( "List of tools for {}@{}",server_version.name, server_version.version); tools.iter().enumerate().for_each(|(tool_index, tool)| { - println!(" {}. {} : {}", - tool_index + 1, - tool.name, - tool.description.clone().unwrap_or_default() - ); + println!(" {}. {} : {}", tool_index + 1, tool.name, tool.description.clone().unwrap_or_default()); }); - println!("Call \"add\" tool with 100 and 28 ..."); - // Create a `Map` to represent the tool parameters - let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); - let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; - - // invoke the tool - let result = client.call_tool(request).await?; - - println!("{}",result.content.first().unwrap().as_text_content()?.text); - client.shut_down().await?; - Ok(()) } - ``` -Here is the output : +## Usage Examples -![rust-mcp-sdk-client-output](assets/examples/mcp-client-sample-code.jpg) +πŸ‘‰ For full examples (stdio, Streamable HTTP, clients, auth, etc.), see the [examples/](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples) directory. -> your results may vary slightly depending on the version of the MCP Server in use when you run it. +πŸ‘‰ If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) -### MCP Client (Streamable HTTP) -```rs +See [hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : -// STEP 1: Custom Handler to handle incoming MCP Messages -pub struct MyClientHandler; +![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) -#[async_trait] -impl ClientHandler for MyClientHandler { - // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs -} -#[tokio::main] -async fn main() -> SdkResult<()> { +## Macros +Enable with the `macros` feature. - // Step2 : Define client details and capabilities - let client_details: InitializeRequestParams = InitializeRequestParams { - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "simple-rust-mcp-client-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (SSE)".to_string()), - }, - protocol_version: LATEST_PROTOCOL_VERSION.into(), - }; +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. - // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. - let transport_options = StreamableTransportOptions { - mcp_url: MCP_SERVER_URL.to_string(), - request_options: RequestOptions { - ..RequestOptions::default() - }, - }; +### β—Ύ`mcp_tool` +Generate a [Tool](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.Tool.html) from a struct, with rich metadata (icons, execution hints, etc.). - // STEP 4: instantiate the custom handler that is responsible for handling MCP messages - let handler = MyClientHandler {}; +example usage: +```rs +#[mcp_tool( + name = "write_file", + title = "Write File Tool", + description = "Create a new file or completely overwrite an existing file with new content.", + destructive_hint = false idempotent_hint = false open_world_hint = false read_only_hint = false, + meta = r#"{ "key" : "value", "string_meta" : "meta value", "numeric_meta" : 15}"#, + execution(task_support = "optional"), + icons = [(src = "https:/website.com/write.png", mime_type = "image/png", sizes = ["128x128"], theme = "light")] +)] +#[derive(rust_mcp_macros::JsonSchema)] +pub struct WriteFileTool { + /// The target file's path for writing content. + pub path: String, + /// The string content to be written to the file + pub content: String, +} +``` - // STEP 5: create the client with transport options and the handler - let client = client_runtime::with_transport_options(client_details, transport_options, handler); +πŸ“ For complete documentation, example usage, and a list of all available attributes, please refer to https://crates.io/crates/rust-mcp-macros. - // STEP 6: start the MCP client - client.clone().start().await?; +### β—Ύ `tool_box!()` +Automatically generates an enum based on the provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. - // STEP 7: use client methods to communicate with the MCP Server as you wish +```rs +tool_box!(GreetingTools, [SayHelloTool, SayGoodbyeTool]); - // Retrieve and display the list of tools available on the server - let server_version = client.server_version().unwrap(); - let tools = client.list_tools(None).await?.tools; - println!("List of tools for {}@{}", server_version.name, server_version.version); +let tools: Vec = GreetingTools::tools(); +`` - tools.iter().enumerate().for_each(|(tool_index, tool)| { - println!(" {}. {} : {}", - tool_index + 1, - tool.name, - tool.description.clone().unwrap_or_default() - ); - }); +πŸ’» For a real-world example, check out [tools/](https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools) and +[handle_call_tool_request(...)](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L195) in [rust-mcp-filesystem](https://github.com/rust-mcp-stack/rust-mcp-filesystem) project - println!("Call \"add\" tool with 100 and 28 ..."); - // Create a `Map` to represent the tool parameters - let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); - let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; +### β—Ύ [mcp_elicit](https://crates.io/crates/rust-mcp-macros) +Generates type-safe elicitation (Form or URL mode) for user input. - // invoke the tool - let result = client.call_tool(request).await?; +example usage: +```rs +#[mcp_elicit(message = "Please enter your info", mode = form)] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, +} - println!("{}",result.content.first().unwrap().as_text_content()?.text); +// Sends a request to the client asking the user to provide input +let result: ElicitResult = server.request_elicitation(UserInfo::elicit_request_params()).await?; - client.shut_down().await?; +// Convert result.content into a UserInfo instance +let user_info = UserInfo::from_elicit_result_content(result.content)?; - Ok(()) +println!("name: {}", user_info.name); +println!("age: {}", user_info.age); +println!("email: {}",user.email.clone().unwrap_or("not provider".into())); +println!("tags: {}", user_info.tags.join(",")); ``` -πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. +πŸ“ For complete documentation, example usage, and a list of all available attributes, please refer to https://crates.io/crates/rust-mcp-macros. +### β—Ύ `mcp_icon!()` +A convenient icon builder for implementations and tools, offering full attribute support including theme, size, mime, and more. -### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: - -```diff -- let transport = StdioTransport::create_with_server_launch( -- "npx", -- vec![ "-y".to_string(), "@modelcontextprotocol/server-everything".to_string()], -- None, TransportOptions::default() --)?; -+ let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; +example usage: +```rs +let icon: crate::schema::Icon = mcp_icon!( + src = "http://website.com/icon.png", + mime_type = "image/png", + sizes = ["64x64"], + theme = "dark" + ); ``` -πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. - - ## Authentication MCP server can verify tokens issued by other systems, integrate with external identity providers, or manage the entire authentication process itself. Each option offers a different balance of simplicity, security, and control. @@ -404,120 +410,12 @@ MCP server can verify tokens issued by other systems, integrate with external id - [WorkOS autn example](crates/rust-mcp-extra/README.md#workos-authkit) - ### OAuthProxy OAuthProxy enables authentication with OAuth providers that don’t support Dynamic Client Registration (DCR).It accepts any client registration request, handles the DCR on your server side and then uses your pre-registered app credentials upstream.The proxy also forwards callbacks, allowing dynamic redirect URIs to work with providers that require fixed ones. > ⚠️ OAuthProxy support is still in development, please use RemoteAuthProvider for now. - - -## Macros -[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. - -> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. - -### mcp_tool -`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. - -Usage example: -```rust -#[mcp_tool( - name = "move_file", - title="Move File", - description = concat!("Move or rename files and directories. Can move files between directories ", -"and rename them in a single operation. If the destination exists, the ", -"operation will fail. Works across different directories and can be used ", -"for simple renaming within the same directory. ", -"Both source and destination must be within allowed directories."), - destructive_hint = false, - idempotent_hint = false, - open_world_hint = false, - read_only_hint = false -)] -#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] -pub struct MoveFileTool { - /// The source path of the file to move. - pub source: String, - /// The destination path to move the file to. - pub destination: String, -} - -// Now we can call `tool()` method on it to get a Tool instance -let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); - -``` - -πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools - - -### tool_box -`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. - -It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. - -Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. - -Usage example: -```rust - // Accepts an array of tools and generates an enum named `FileSystemTools`, - // where each tool becomes a variant of the enum. - tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); - - // now in the app, we can use the FileSystemTools, like: - let all_tools: Vec = FileSystemTools::tools(); -``` - -πŸ’» To see a real-world example of that please see : -- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) -- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) -- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) - - - -### mcp_elicit -The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: -- A `message()` method returning the elicitation message as a string. -- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. -- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. - -### Attributes - -- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. - -Usage example: -```rust -// A struct that could be used to send elicit request and get the input from the user -#[mcp_elicit(message = "Please enter your info")] -#[derive(JsonSchema)] -pub struct UserInfo { - #[json_schema( - title = "Name", - description = "The user's full name", - min_length = 5, - max_length = 100 - )] - pub name: String, - /// Is user a student? - #[json_schema(title = "Is student?", default = true)] - pub is_student: Option, - - /// User's favorite color - pub favorate_color: Colors, -} - -// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance -let result: ElicitResult = server - .elicit_input(UserInfo::message(), UserInfo::requested_schema()) - .await?; - -// Create a UserInfo instance using data provided by the user on the client side -let user_info = UserInfo::from_content_map(result.content)?; - -``` -πŸ’» For mre info please see : -- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros ## HyperServerOptions @@ -531,89 +429,22 @@ A typical example of creating a HyperServer that exposes the MCP server via Stre let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { host: "127.0.0.1".to_string(), - enable_ssl: true, + port: 8080, + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability + auth: Some(Arc::new(auth_provider)), // enable authentication + sse_support: false, ..Default::default() }, ); server.start().await?; - ``` -Here is a list of available options with descriptions for configuring the HyperServer: -```rs - -pub struct HyperServerOptions { - /// Hostname or IP address the server will bind to (default: "127.0.0.1") - pub host: String, - - /// Hostname or IP address the server will bind to (default: "8080") - pub port: u16, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>>, - - /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) - pub custom_streamable_http_endpoint: Option, - - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Event store for resumability support - /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages - pub event_store: Option>, - - /// This setting only applies to streamable HTTP. - /// If true, the server will return JSON responses instead of starting an SSE stream. - /// This can be useful for simple request/response scenarios without streaming. - /// Default is false (SSE streams are preferred). - pub enable_json_response: Option, - - /// Interval between automatic ping messages sent to clients to detect disconnects - pub ping_interval: Duration, - - /// Enables SSL/TLS if set to `true` - pub enable_ssl: bool, - - /// Path to the SSL/TLS certificate file (e.g., "cert.pem"). - /// Required if `enable_ssl` is `true`. - pub ssl_cert_path: Option, - - /// Path to the SSL/TLS private key file (e.g., "key.pem"). - /// Required if `enable_ssl` is `true`. - pub ssl_key_path: Option, +πŸ“ Refer to [HyperServerOptions](https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/hyper_servers/server.rs#L43) for a complete overview of HyperServerOptions attributes and options. - /// List of allowed host header values for DNS rebinding protection. - /// If not specified, host validation is disabled. - pub allowed_hosts: Option>, - - /// List of allowed origin header values for DNS rebinding protection. - /// If not specified, origin validation is disabled. - pub allowed_origins: Option>, - - /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). - /// Default is false for backwards compatibility. - pub dns_rebinding_protection: bool, - - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - - /// Optional authentication provider for protecting MCP server. - pub auth: Option>, -} - -``` ### Security Considerations @@ -637,28 +468,19 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. - - `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. -#### MCP Protocol Versions with Corresponding Features - -- `2025_06_18` : Activates MCP Protocol version 2025-06-18 (enabled by default) -- `2025_03_26` : Activates MCP Protocol version 2025-03-26 -- `2024_11_05` : Activates MCP Protocol version 2024-11-05 - -> Note: MCP protocol versions are mutually exclusive-only one can be active at any given time. - ### Default Features -When you add rust-mcp-sdk as a dependency without specifying any features, all features are included, with the latest MCP Protocol version enabled by default: +When you add rust-mcp-sdk as a dependency without specifying any features, all features are enabled by default ```toml [dependencies] -rust-mcp-sdk = "0.2.0" +rust-mcp-sdk = "0.9.0" ``` diff --git a/assets/rust-mcp-icon.png b/assets/rust-mcp-icon.png new file mode 100644 index 0000000..189ea43 Binary files /dev/null and b/assets/rust-mcp-icon.png differ diff --git a/crates/rust-mcp-extra/Cargo.toml b/crates/rust-mcp-extra/Cargo.toml index 3c3b438..e61a4e2 100644 --- a/crates/rust-mcp-extra/Cargo.toml +++ b/crates/rust-mcp-extra/Cargo.toml @@ -13,7 +13,7 @@ rust-version = { workspace = true } exclude = ["assets/", "tests/"] [dependencies] -rust-mcp-sdk = { version = "0.7.4" , path = "../rust-mcp-sdk", default-features = false, features=["server","2025_06_18","auth","hyper-server","macros"] } +rust-mcp-sdk = { version = "0.7.4" , path = "../rust-mcp-sdk", default-features = false, features=["server","auth","hyper-server","macros"] } base64 = {workspace = true, optional=true} url= {workspace = true, optional=true} nanoid = {version="0.4", optional=true} diff --git a/crates/rust-mcp-extra/examples/common/handler.rs b/crates/rust-mcp-extra/examples/common/handler.rs index 5fc0714..8f8f8fd 100644 --- a/crates/rust-mcp-extra/examples/common/handler.rs +++ b/crates/rust-mcp-extra/examples/common/handler.rs @@ -1,16 +1,14 @@ -use std::sync::Arc; - +use crate::common::tool::ShowAuthInfo; use async_trait::async_trait; use rust_mcp_sdk::{ mcp_server::ServerHandler, schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, - ListToolsResult, RpcError, + schema_utils::CallToolError, CallToolRequestParams, CallToolResult, ListToolsResult, + PaginatedRequestParams, RpcError, }, McpServer, }; - -use crate::common::tool::ShowAuthInfo; +use std::sync::Arc; pub struct McpServerHandler; #[async_trait] @@ -18,7 +16,7 @@ impl ServerHandler for McpServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult async fn handle_list_tools_request( &self, - _request: ListToolsRequest, + _request: Option, _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { @@ -31,16 +29,16 @@ impl ServerHandler for McpServerHandler { /// Handles incoming CallToolRequest and processes it using the appropriate tool. async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { - if request.params.name.eq(&ShowAuthInfo::tool_name()) { + if params.name.eq(&ShowAuthInfo::tool_name()) { let tool = ShowAuthInfo::default(); tool.call_tool(runtime.auth_info_cloned().await) } else { Err(CallToolError::from_message(format!( "Tool \"{}\" does not exists or inactive!", - request.params.name, + params.name, ))) } } diff --git a/crates/rust-mcp-extra/examples/common/utils.rs b/crates/rust-mcp-extra/examples/common/utils.rs index 6889b56..5091b31 100644 --- a/crates/rust-mcp-extra/examples/common/utils.rs +++ b/crates/rust-mcp-extra/examples/common/utils.rs @@ -1,6 +1,9 @@ -use rust_mcp_sdk::schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, - LATEST_PROTOCOL_VERSION, +use rust_mcp_sdk::{ + mcp_icon, + schema::{ + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + LATEST_PROTOCOL_VERSION, + }, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -10,6 +13,16 @@ pub fn create_server_info(server_name: &str) -> InitializeResult { name: server_name.to_string(), version: "0.1.0".to_string(), title: Some(server_name.to_string()), + description: Some(server_name.to_string()), + icons: vec![ + mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + ) + ], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".to_string()), }, capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), diff --git a/crates/rust-mcp-extra/examples/keycloak-auth.rs b/crates/rust-mcp-extra/examples/keycloak-auth.rs index b4b191b..58bd9e1 100644 --- a/crates/rust-mcp-extra/examples/keycloak-auth.rs +++ b/crates/rust-mcp-extra/examples/keycloak-auth.rs @@ -7,6 +7,7 @@ use rust_mcp_extra::auth_provider::keycloak::{KeycloakAuthOptions, KeycloakAuthP use rust_mcp_sdk::{ error::SdkResult, mcp_server::{hyper_server, HyperServerOptions}, + ToMcpServerHandler, }; use std::{env, sync::Arc}; @@ -31,7 +32,7 @@ async fn main() -> SdkResult<()> { let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { host: "localhost".to_string(), port: 3000, diff --git a/crates/rust-mcp-extra/examples/scalekit-auth.rs b/crates/rust-mcp-extra/examples/scalekit-auth.rs index 8fd625f..cf76efb 100644 --- a/crates/rust-mcp-extra/examples/scalekit-auth.rs +++ b/crates/rust-mcp-extra/examples/scalekit-auth.rs @@ -6,7 +6,9 @@ use crate::common::{ use rust_mcp_extra::auth_provider::scalekit::{ScalekitAuthOptions, ScalekitAuthProvider}; use rust_mcp_sdk::{ error::SdkResult, + event_store::InMemoryEventStore, mcp_server::{hyper_server, HyperServerOptions}, + ToMcpServerHandler, }; use std::{env, sync::Arc}; @@ -32,10 +34,11 @@ async fn main() -> SdkResult<()> { let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { host: "127.0.0.1".to_string(), - port: 3000, + port: 8080, + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability auth: Some(Arc::new(auth_provider)), // enable authentication sse_support: false, ..Default::default() diff --git a/crates/rust-mcp-extra/examples/workos-auth.rs b/crates/rust-mcp-extra/examples/workos-auth.rs index 01d980b..c948e90 100644 --- a/crates/rust-mcp-extra/examples/workos-auth.rs +++ b/crates/rust-mcp-extra/examples/workos-auth.rs @@ -7,6 +7,7 @@ use rust_mcp_extra::auth_provider::work_os::{WorkOSAuthOptions, WorkOsAuthProvid use rust_mcp_sdk::{ error::SdkResult, mcp_server::{hyper_server, HyperServerOptions}, + ToMcpServerHandler, }; use std::{env, sync::Arc}; @@ -29,7 +30,7 @@ async fn main() -> SdkResult<()> { let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { host: "127.0.0.1".to_string(), port: 3000, diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml index 58490d4..424e014 100644 --- a/crates/rust-mcp-macros/Cargo.toml +++ b/crates/rust-mcp-macros/Cargo.toml @@ -18,13 +18,12 @@ description = "A procedural macro, part of the rust-mcp-sdk ecosystem, that deri [dependencies] serde_json = { workspace = true } serde = { version = "1.0", features = ["derive"] } -syn = "2.0" +syn = {version="2.0", features = ["full", "extra-traits","parsing"]} quote = "1.0" proc-macro2 = "1.0" - [dev-dependencies] -rust-mcp-schema = { workspace = true, default-features = false } +rust-mcp-schema = { workspace = true , features=["latest","schema_utils"]} [lints] workspace = true @@ -32,18 +31,7 @@ workspace = true [lib] proc-macro = true - [features] # defalt features -default = ["2025_06_18"] # Default features - -# activates the latest MCP schema version, this will be updated once a new version of schema is published -latest = ["2025_06_18"] - -# enables mcp schema version 2025_06_18 -2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] -# enables mcp schema version 2025_03_26 -2025_03_26 = ["rust-mcp-schema/2025_03_26", "rust-mcp-schema/schema_utils"] -# enables mcp schema version 2024_11_05 -2024_11_05 = ["rust-mcp-schema/2024_11_05", "rust-mcp-schema/schema_utils"] +default = [] sdk = [] diff --git a/crates/rust-mcp-macros/README.md b/crates/rust-mcp-macros/README.md index fc463cd..57b000d 100644 --- a/crates/rust-mcp-macros/README.md +++ b/crates/rust-mcp-macros/README.md @@ -1,195 +1,186 @@ -# rust-mcp-macros. +# rust-mcp-macros + +`rust-mcp-macros` provides procedural macros for the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) ecosystem. These macros simplify the generation of `tools` and `elicitation` schemas compatible with the latest MCP protocol specifications. + + +The available macros are: + +[mcp_tool](#mcp_tool-macro): Generates a [rust_mcp_schema::Tool](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.Tool.html) instance from a struct. +[mcp_elicit](#mcp_elicit): Generates elicitation logic for gathering user input based on a struct's schema, supporting [Form](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.ElicitRequestFormParams.html) and [URL](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.ElicitRequestUrlParams.html) modes. +[derive(JsonSchema)]: Derives a JSON Schema representation for structs and enums, used by the other macros for schema generation. + +These macros rely on [rust_mcp_schema](https://crates.io/crates/rust-mcp-schema) and serde_json for schema handling. ## mcp_tool Macro +A procedural macro to generate a [rust_mcp_schema::Tool](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.Tool.html) instance from a struct. The struct must derive **JsonSchema**. + -A procedural macro, part of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) ecosystem, to generate `rust_mcp_schema::Tool` instance from a struct. +### Generated methods: -The `mcp_tool` macro generates an implementation for the annotated struct that includes: +- `tool_name()`: Returns the tool's name. +- `tool()`: Returns a [rust_mcp_schema::Tool](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.Tool.html) with name, description, input schema, and optional metadata/annotations. +- `request_params()`: Returns a [CallToolRequestParams](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.CallToolRequestParams.html) pre-initialized with the tool's name, ready for building a tool call via the builder pattern. -- A `tool_name()` method returning the tool's name as a string. -- A `tool()` method returning a `rust_mcp_schema::Tool` instance with the tool's name, - description, and input schema derived from the struct's fields. -## Attributes +### Attributes -- `name` - The name of the tool (required, non-empty string). +- `name`: Required, non-empty string for the tool's name. +- `description`: Required, a full and detailed description of the tool’s functionality. +- `title`: Optional human readable title for the tools. - `description` - A description of the tool (required, non-empty string). -- `title` - An optional human-readable and easily understood title. - `meta` - An optional JSON string that provides additional metadata for the tool. +- `execution`: Optional, controls task support. Accepted values are "required", "optional", and "forbidden". +- `icons`: Optional array of icons with src (required), mime_type, sizes (array of strings), theme ("light" or "dark"). - `destructive_hint` – Optional boolean, indicates whether the tool may make destructive changes to its environment. - `idempotent_hint` – Optional boolean, indicates whether repeated calls with the same input have the same effect. - `open_world_hint` – Optional boolean, indicates whether the tool can interact with external or unknown entities. - `read_only_hint` – Optional boolean, indicates whether the tool makes no modifications to its environment. - -## Usage Example +### Usage Example ```rust +use rust_mcp_macros::{mcp_tool, JsonSchema}; +use rust_mcp_schema::Tool; #[mcp_tool( - name = "write_file", - title = "Write File Tool" - description = "Create a new file or completely overwrite an existing file with new content." - destructive_hint = false - idempotent_hint = false - open_world_hint = false - read_only_hint = false - meta = r#"{ - "key" : "value", - "string_meta" : "meta value", - "numeric_meta" : 15 - }"# + name = "write_file", + title = "Write File Tool", + description = "Create or overwrite a file with content.", + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false, + execution(task_support = "optional"), + icons = [ + (src = "https:/mywebsite.com/write.png", mime_type = "image/png", sizes = ["128x128"], theme = "light"), + (src = "https:/mywebsite.com/write_dark.svg", mime_type = "image/svg+xml", sizes = ["64x64","128x128"], theme = "dark") + ], + meta = r#"{"key": "value"}"# )] -#[derive(rust_mcp_macros::JsonSchema)] +#[derive(JsonSchema)] pub struct WriteFileTool { - /// The target file's path for writing content. + /// The target file's path. pub path: String, /// The string content to be written to the file pub content: String, } -fn main() { - - assert_eq!(WriteFileTool::tool_name(), "write_file"); - - let tool: rust_mcp_schema::Tool = WriteFileTool::tool(); - assert_eq!(tool.name, "write_file"); - assert_eq!(tool.title.as_ref().unwrap(), "Write File Tool"); - assert_eq!( tool.description.unwrap(),"Create a new file or completely overwrite an existing file with new content."); - - let meta: &Map = tool.meta.as_ref().unwrap(); - assert_eq!( - meta.get("key").unwrap(), - &Value::String("value".to_string()) - ); - - let schema_properties = tool.input_schema.properties.unwrap(); - assert_eq!(schema_properties.len(), 2); - assert!(schema_properties.contains_key("path")); - assert!(schema_properties.contains_key("content")); - - // get the `content` prop from schema - let content_prop = schema_properties.get("content").unwrap(); - - // assert the type - assert_eq!(content_prop.get("type").unwrap(), "string"); - // assert the description - assert_eq!( - content_prop.get("description").unwrap(), - "The string content to be written to the file" - ); +WriteFileTool::request_params().with_arguments( + json!({"path":"./test.txt","content":"hello tool"}) + .as_object() + .unwrap() + .clone(), +) + +// send a call_tool requeest: +let result = client.request_tool_call( WriteFileTool::request_params().with_arguments( + json!({"path":"./test.txt","content":"hello tool"}).as_object().unwrap().clone(), +))?; + +// Handle ListToolsRequest, return list of available tools as ListToolsResult +async fn handle_list_tools_request( + &self, + request: Option, + runtime: Arc, +) -> std::result::Result { + Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: vec![WriteFileTool::tool()], + }) } ``` +## mcp_elicit Macro - -**Note**: The following attributes are available only in version `2025_03_26` and later of the MCP Schema, and their values will be used in the [annotations](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5557) attribute of the *[Tool struct](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5554-L5566). - -- `destructive_hint` -- `idempotent_hint` -- `open_world_hint` -- `read_only_hint` - - - +The `mcp_elicit` macro generates implementations for eliciting user input based on the struct's schema. The struct must derive **JsonSchema**. It supports two modes: **form** (default) for schema-based forms and **url** for redirecting the user to an external URL to collect input. -## mcp_elicit Macro +### Generated methods: -The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: +- `message()`: Returns the elicitation message. +- `elicit_request_params(elicitation_id)`: Returns [ElicitRequestParams](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.ElicitRequestUrlParams.html) (FormParams or UrlParams based on mode). +- `from_elicit_result_content(content)`: Parses user input back into the struct. -- A `message()` method returning the elicitation message as a string. -- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. -- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. ### Attributes -- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. +- `message` : Optional string (or concat!(...)), defaults to empty. +- `mode`: Optional, elicitation mode ("form"|"URL), defaults to form. +- `url` = "https://example.com/form": Required if mode = url. ### Supported Field Types -- `String`: Maps to `ElicitResultContentValue::String`. -- `bool`: Maps to `ElicitResultContentValue::Boolean`. -- `i32`: Maps to `ElicitResultContentValue::Integer` (with bounds checking). -- `i64`: Maps to `ElicitResultContentValue::Integer`. -- `enum` Only simple enums are supported. The enum must implement the FromStr trait. +- `String`: Maps to [ElicitResultContentPrimitive::String](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/enum.ElicitResultContentPrimitive.html). +- `bool`: Maps to [ElicitResultContentPrimitive::Boolean](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/enum.ElicitResultContentPrimitive.html). +- `i32`: Maps to [ElicitResultContentPrimitive::Integer](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/enum.ElicitResultContentPrimitive.html) (with bounds checking). +- `i64`: Maps to [ElicitResultContentPrimitive::Integer](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/enum.ElicitResultContentPrimitive.html). +- `Vec`: Maps to [ElicitResultContent::StringArray](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/enum.ElicitResultContent.html). - `Option`: Supported for any of the above types, mapping to `None` if the field is missing. -### Usage Example +### Usage Example (Form Mode) ```rust -use rust_mcp_sdk::macros::{mcp_elicit, JsonSchema}; -use rust_mcp_sdk::schema::RpcError; -use std::str::FromStr; - -// Simple enum with FromStr trait implemented -#[derive(JsonSchema, Debug)] -pub enum Colors { - #[json_schema(title = "Green Color")] - Green, - #[json_schema(title = "Red Color")] - Red, -} -impl FromStr for Colors { - type Err = RpcError; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "green" => Ok(Colors::Green), - "red" => Ok(Colors::Red), - _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), - } + #[mcp_elicit(message = "Please enter your info", mode = form)] + #[derive(JsonSchema)] + pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, } -} -// A struct that could be used to send elicit request and get the input from the user -#[mcp_elicit(message = "Please enter your info")] -#[derive(JsonSchema)] -pub struct UserInfo { - #[json_schema( - title = "Name", - description = "The user's full name", - min_length = 5, - max_length = 100 - )] - pub name: String, - - /// Email address of the user - #[json_schema(title = "Email", format = "email")] - pub email: Option, - - /// The user's age in years - #[json_schema(title = "Age", minimum = 15, maximum = 125)] - pub age: i32, - - /// Is user a student? - #[json_schema(title = "Is student?", default = true)] - pub is_student: Option, - - /// User's favorite color - pub favorate_color: Colors, -} + // Sends a request to the client asking the user to provide input + let result: ElicitResult = server.request_elicitation(UserInfo::elicit_request_params()).await?; - // .... - // ....... - // ........... + // Convert result.content into a UserInfo instance + let user_info = UserInfo::from_elicit_result_content(result.content)?; + + println!("name: {}", user_info.name); + println!("age: {}", user_info.age); + println!("email: {}",user.email.clone().unwrap_or("not provider".into())); + println!("tags: {}", user_info.tags.join(",")); - // send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance +``` - let result: ElicitResult = server - .elicit_input(UserInfo::message(), UserInfo::requested_schema()) - .await?; - // Create a UserInfo instance using data provided by the user on the client side - let user_info = UserInfo::from_content_map(result.content)?; +### Usage Example (URL Mode) +```rust +#[mcp_elicit(message = "Complete the form", mode = url, url = "https://example.com/form")] + #[derive(JsonSchema)] + pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, + } + let elicit_url = UserInfo::elicit_url_params("elicit_10".into()); + + // Sends a request to the client asking the user to provide input + let result: ElicitResult = server.request_elicitation(UserInfo::elicit_request_params()).await?; + + // Convert result.content into a UserInfo instance + let user_info = UserInfo::from_elicit_result_content(result.content)?; + + println!("name: {}", user_info.name); + println!("age: {}", user_info.age); + println!("email: {}", user_info.email.unwrap_or_default(); + println!("tags: {}", user_info.tags.join(",")); ``` - --- Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk), a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! diff --git a/crates/rust-mcp-macros/src/elicit.rs b/crates/rust-mcp-macros/src/elicit.rs new file mode 100644 index 0000000..1ffdea9 --- /dev/null +++ b/crates/rust-mcp-macros/src/elicit.rs @@ -0,0 +1,2 @@ +pub(crate) mod generator; +pub(crate) mod parser; diff --git a/crates/rust-mcp-macros/src/elicit/generator.rs b/crates/rust-mcp-macros/src/elicit/generator.rs new file mode 100644 index 0000000..fa72a4d --- /dev/null +++ b/crates/rust-mcp-macros/src/elicit/generator.rs @@ -0,0 +1,189 @@ +use crate::is_option; +use crate::is_vec_string; +use quote::quote; +use quote::ToTokens; +use syn::{ + punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, Lit, Meta, PathArguments, Token, + Type, +}; + +fn json_field_name(field: &syn::Field) -> String { + field + .attrs + .iter() + .filter(|a| a.path().is_ident("serde")) + .find_map(|attr| { + // Parse everything inside #[serde(...)] + let items = attr + .parse_args_with(Punctuated::::parse_terminated) + .ok()?; + + for item in items { + match item { + // Case 1: #[serde(rename = "field_name")] + Meta::NameValue(nv) if nv.path.is_ident("rename") => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = nv.value + { + return Some(lit_str.value()); + } + } + + // Case 2: #[serde(rename(serialize = "a", deserialize = "b"))] + Meta::List(list) if list.path.is_ident("rename") => { + let inner_items = list + .parse_args_with(Punctuated::::parse_terminated) + .ok()?; + + for inner in inner_items { + if let Meta::NameValue(nv) = inner { + if nv.path.is_ident("serialize") || nv.path.is_ident("deserialize") + { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = nv.value + { + return Some(lit_str.value()); + } + } + } + } + } + + _ => {} + } + } + None + }) + .unwrap_or_else(|| field.ident.as_ref().unwrap().to_string()) +} + +// Form implementation generation +pub fn generate_from_impl( + fields: &Punctuated, + base: &proc_macro2::TokenStream, +) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let mut assigns = Vec::new(); + let mut idents = Vec::new(); + + for field in fields { + let ident = field.ident.as_ref().unwrap(); + let key = json_field_name(field); + let ty = &field.ty; + + idents.push(ident); + + let block = if is_option(ty) { + let inner = get_option_inner(ty); + let (expected, pat, conv) = match_type(inner, &key, base); + quote! { + let #ident = match map.remove(#key) { + Some(#pat) => Some(#conv), + Some(other) => return Err(RpcError::parse_error().with_message(format!( + "Type mismatch for optional field '{}': expected {}, got {:?}", + #key, #expected, other + ))), + None => None, + }; + } + } else { + let (expected, pat, conv) = match_type(ty, &key, base); + quote! { + let #ident = match map.remove(#key) { + Some(#pat) => #conv, + Some(other) => return Err(RpcError::parse_error().with_message(format!( + "Type mismatch for required field '{}': expected {}, got {:?}", + #key, #expected, other + ))), + None => return Err(RpcError::parse_error().with_message(format!("Missing required field '{}'", #key))), + }; + } + }; + + assigns.push(block); + } + + (quote! { #(#assigns)* }, quote! { Self { #(#idents),* } }) +} + +pub fn get_option_inner(ty: &Type) -> &Type { + if let Type::Path(p) = ty { + if let Some(seg) = p.path.segments.last() { + if seg.ident == "Option" { + if let PathArguments::AngleBracketed(ref args) = seg.arguments { + if let Some(syn::GenericArgument::Type(inner)) = args.args.first() { + return inner; + } + } + } + } + } + panic!("Not Option") +} + +pub fn match_type( + ty: &Type, + key: &str, + base: &proc_macro2::TokenStream, +) -> (String, proc_macro2::TokenStream, proc_macro2::TokenStream) { + if is_vec_string(ty) { + return ( + "string array".into(), // expected + quote! { V::StringArray(v) }, + quote! { v }, + ); + }; + + match ty { + Type::Path(p) if p.path.is_ident("String") => ( + "string".into(), + quote! { V::Primitive(#base::ElicitResultContentPrimitive::String(v)) }, + quote! { v.clone() }, + ), + Type::Path(p) if p.path.is_ident("bool") => ( + "bool".into(), + quote! { V::Primitive(#base::ElicitResultContentPrimitive::Boolean(v)) }, + quote! { v }, + ), + Type::Path(p) if p.path.is_ident("i32") => ( + "i32".into(), + quote! { V::Primitive(#base::ElicitResultContentPrimitive::Integer(v)) }, + quote! { (v).try_into().map_err(|_| RpcError::parse_error().with_message(format!("i32 overflow in field '{}'", #key)))? }, + ), + Type::Path(p) if p.path.is_ident("i64") => ( + "i64".into(), + quote! { V::Primitive(#base::ElicitResultContentPrimitive::Integer(v)) }, + quote! { v }, + ), + _ => panic!("Unsupported type in mcp_elicit: {}", ty.to_token_stream()), + } +} + +pub fn generate_form_schema( + struct_name: &Ident, + base: &proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { + quote! { + { + let json = #struct_name::json_schema(); + let properties = json.get("properties") + .and_then(|v| v.as_object()) + .into_iter() + .flatten() + .filter_map(|(k, v)| #base::PrimitiveSchemaDefinition::try_from(v.as_object()?).ok().map(|def| (k.clone(), def))) + .collect(); + + let required = json.get("required") + .and_then(|v| v.as_array()) + .into_iter() + .flatten() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + + #base::ElicitFormSchema::new(properties, required, None) + } + } +} diff --git a/crates/rust-mcp-macros/src/elicit/parser.rs b/crates/rust-mcp-macros/src/elicit/parser.rs new file mode 100644 index 0000000..fb1272a --- /dev/null +++ b/crates/rust-mcp-macros/src/elicit/parser.rs @@ -0,0 +1,84 @@ +use syn::{punctuated::Punctuated, Expr, ExprLit, Lit, LitStr, Meta, Token}; + +pub struct ElicitArgs { + pub message: LitStr, + pub mode: ElicitMode, +} + +pub enum ElicitMode { + Form, + Url { url: LitStr }, +} + +impl syn::parse::Parse for ElicitArgs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut message = None; + let mut mode = ElicitMode::Form; // default + let mut url_lit: Option = None; + + let metas = Punctuated::::parse_terminated(input)?; + + // First pass + for meta in &metas { + if let Meta::NameValue(nv) = meta { + if let Some(ident) = nv.path.get_ident() { + if ident == "message" { + if let Expr::Lit(ExprLit { + lit: Lit::Str(s), .. + }) = &nv.value + { + message = Some(s.clone()); + } + } else if ident == "url" { + if let Expr::Lit(ExprLit { + lit: Lit::Str(s), .. + }) = &nv.value + { + url_lit = Some(s.clone()); + } + } + } + } + } + + // Second pass: handle `mode = url` or `mode = form` + for meta in &metas { + if let Meta::NameValue(nv) = meta { + if let Some(ident) = nv.path.get_ident() { + if ident == "mode" { + if let Expr::Path(path) = &nv.value { + if let Some(k) = path.path.get_ident() { + match k.to_string().as_str() { + "url" => { + let the_url = url_lit.clone().ok_or_else(|| { + syn::Error::new_spanned(nv, "when `mode = url`, you must also provide `url = \"https://...\"`") + })?; + mode = ElicitMode::Url { url: the_url }; + } + "form" => { + mode = ElicitMode::Form; + } + _ => { + return Err(syn::Error::new_spanned( + k, + "mode must be `form` or `url`", + )) + } + } + } + } else { + return Err(syn::Error::new_spanned( + &nv.value, + "mode must be `form` or `url`", + )); + } + } + } + } + } + + let message = message.unwrap_or_else(|| LitStr::new("", proc_macro2::Span::call_site())); + + Ok(Self { message, mode }) + } +} diff --git a/crates/rust-mcp-macros/src/lib.rs b/crates/rust-mcp-macros/src/lib.rs index 473792c..30bdc71 100644 --- a/crates/rust-mcp-macros/src/lib.rs +++ b/crates/rust-mcp-macros/src/lib.rs @@ -1,312 +1,17 @@ extern crate proc_macro; +mod elicit; +mod tool; mod utils; +use crate::elicit::generator::{generate_form_schema, generate_from_impl}; +use crate::elicit::parser::{ElicitArgs, ElicitMode}; +use crate::tool::generator::{generate_tool_tokens, ToolTokens}; +use crate::tool::parser::{IconThemeDsl, McpToolMacroAttributes}; use proc_macro::TokenStream; use quote::quote; -use syn::{ - parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr, - ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments, Token, Type, -}; -use utils::{is_option, renamed_field, type_to_json_schema}; - -/// Represents the attributes for the `mcp_tool` procedural macro. -/// -/// This struct parses and validates the attributes provided to the `mcp_tool` macro. -/// The `name` and `description` attributes are required and must not be empty strings. -/// -/// # Fields -/// * `name` - A string representing the tool's name (required). -/// * `description` - A string describing the tool (required). -/// * `meta` - An optional JSON string for metadata. -/// * `title` - An optional string for the tool's title. -/// * The following fields are available only with the `2025_03_26` feature and later: -/// * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`. -/// * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`. -/// * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`. -/// * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`. -/// -struct McpToolMacroAttributes { - name: Option, - description: Option, - #[cfg(feature = "2025_06_18")] - meta: Option, // Store raw JSON string instead of parsed Map - #[cfg(feature = "2025_06_18")] - title: Option, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - destructive_hint: Option, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - idempotent_hint: Option, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - open_world_hint: Option, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - read_only_hint: Option, -} - -use syn::parse::ParseStream; - -use crate::utils::{generate_enum_parse, is_enum}; - -struct ExprList { - exprs: Punctuated, -} - -impl Parse for ExprList { - fn parse(input: ParseStream) -> syn::Result { - Ok(ExprList { - exprs: Punctuated::parse_terminated(input)?, - }) - } -} - -impl Parse for McpToolMacroAttributes { - /// Parses the macro attributes from a `ParseStream`. - /// - /// This implementation extracts `name`, `description`, `meta`, and `title` from the attribute input. - /// The `name` and `description` must be provided as string literals and be non-empty. - /// The `meta` attribute must be a valid JSON object provided as a string literal, and `title` must be a string literal. - /// - /// # Errors - /// Returns a `syn::Error` if: - /// - The `name` attribute is missing or empty. - /// - The `description` attribute is missing or empty. - /// - The `meta` attribute is provided but is not a valid JSON object. - /// - The `title` attribute is provided but is not a string literal. - fn parse(attributes: syn::parse::ParseStream) -> syn::Result { - let mut instance = Self { - name: None, - description: None, - #[cfg(feature = "2025_06_18")] - meta: None, - #[cfg(feature = "2025_06_18")] - title: None, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - destructive_hint: None, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - idempotent_hint: None, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - open_world_hint: None, - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - read_only_hint: None, - }; - - let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; - for meta in meta_list { - if let Meta::NameValue(meta_name_value) = meta { - let ident = meta_name_value.path.get_ident().unwrap(); - let ident_str = ident.to_string(); - - match ident_str.as_str() { - "name" | "description" => { - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) => lit_str.value(), - Expr::Macro(expr_macro) => { - let mac = &expr_macro.mac; - if mac.path.is_ident("concat") { - let args: ExprList = syn::parse2(mac.tokens.clone())?; - let mut result = String::new(); - for expr in args.exprs { - if let Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) = expr - { - result.push_str(&lit_str.value()); - } else { - return Err(Error::new_spanned( - expr, - "Only string literals are allowed inside concat!()", - )); - } - } - result - } else { - return Err(Error::new_spanned( - expr_macro, - "Only concat!(...) is supported here", - )); - } - } - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a string literal or concat!(...)", - )); - } - }; - match ident_str.as_str() { - "name" => instance.name = Some(value), - "description" => instance.description = Some(value), - _ => {} - } - } - #[cfg(feature = "2025_06_18")] - "meta" => { - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) => lit_str.value(), - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a JSON object as a string literal", - )); - } - }; - // Validate that the string is a valid JSON object - let parsed: serde_json::Value = - serde_json::from_str(&value).map_err(|e| { - Error::new_spanned( - &meta_name_value.value, - format!("Expected a valid JSON object: {e}"), - ) - })?; - if !parsed.is_object() { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a JSON object", - )); - } - instance.meta = Some(value); - } - #[cfg(feature = "2025_06_18")] - "title" => { - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) => lit_str.value(), - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a string literal", - )); - } - }; - instance.title = Some(value); - } - "destructive_hint" | "idempotent_hint" | "open_world_hint" - | "read_only_hint" => { - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - { - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Bool(lit_bool), - .. - }) => lit_bool.value, - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a boolean literal", - )); - } - }; - - match ident_str.as_str() { - "destructive_hint" => instance.destructive_hint = Some(value), - "idempotent_hint" => instance.idempotent_hint = Some(value), - "open_world_hint" => instance.open_world_hint = Some(value), - "read_only_hint" => instance.read_only_hint = Some(value), - _ => {} - } - } - } - _ => {} - } - } - } - - // Validate presence and non-emptiness - if instance - .name - .as_ref() - .map(|s| s.trim().is_empty()) - .unwrap_or(true) - { - return Err(Error::new( - attributes.span(), - "The 'name' attribute is required and must not be empty.", - )); - } - if instance - .description - .as_ref() - .map(|s| s.trim().is_empty()) - .unwrap_or(true) - { - return Err(Error::new( - attributes.span(), - "The 'description' attribute is required and must not be empty.", - )); - } - - Ok(instance) - } -} - -struct McpElicitationAttributes { - message: Option, -} - -impl Parse for McpElicitationAttributes { - fn parse(attributes: syn::parse::ParseStream) -> syn::Result { - let mut instance = Self { message: None }; - let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; - for meta in meta_list { - if let Meta::NameValue(meta_name_value) = meta { - let ident = meta_name_value.path.get_ident().unwrap(); - let ident_str = ident.to_string(); - if ident_str.as_str() == "message" { - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) => lit_str.value(), - Expr::Macro(expr_macro) => { - let mac = &expr_macro.mac; - if mac.path.is_ident("concat") { - let args: ExprList = syn::parse2(mac.tokens.clone())?; - let mut result = String::new(); - for expr in args.exprs { - if let Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) = expr - { - result.push_str(&lit_str.value()); - } else { - return Err(Error::new_spanned( - expr, - "Only string literals are allowed inside concat!()", - )); - } - } - result - } else { - return Err(Error::new_spanned( - expr_macro, - "Only concat!(...) is supported here", - )); - } - } - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a string literal or concat!(...)", - )); - } - }; - instance.message = Some(value) - } - } - } - Ok(instance) - } -} +use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use utils::{base_crate, is_option, is_vec_string, renamed_field, type_to_json_schema}; /// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct. /// @@ -357,84 +62,22 @@ impl Parse for McpElicitationAttributes { pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let input_ident = &input.ident; - - // Conditionally select the path for Tool - let base_crate = if cfg!(feature = "sdk") { - quote! { rust_mcp_sdk::schema } - } else { - quote! { rust_mcp_schema } - }; - let macro_attributes = parse_macro_input!(attributes as McpToolMacroAttributes); - let tool_name = macro_attributes.name.unwrap_or_default(); - let tool_description = macro_attributes.description.unwrap_or_default(); - - #[cfg(not(feature = "2025_06_18"))] - let meta = quote! {}; - #[cfg(feature = "2025_06_18")] - let meta = macro_attributes.meta.map_or(quote! { meta: None, }, |m| { - quote! { meta: Some(serde_json::from_str(#m).expect("Failed to parse meta JSON")), } - }); - - #[cfg(not(feature = "2025_06_18"))] - let title = quote! {}; - #[cfg(feature = "2025_06_18")] - let title = macro_attributes.title.map_or( - quote! { title: None, }, - |t| quote! { title: Some(#t.to_string()), }, - ); - - #[cfg(not(feature = "2025_06_18"))] - let output_schema = quote! {}; - #[cfg(feature = "2025_06_18")] - let output_schema = quote! { output_schema: None,}; - - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - let some_annotations = macro_attributes.destructive_hint.is_some() - || macro_attributes.idempotent_hint.is_some() - || macro_attributes.open_world_hint.is_some() - || macro_attributes.read_only_hint.is_some(); - - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - let annotations = if some_annotations { - let destructive_hint = macro_attributes - .destructive_hint - .map_or(quote! {None}, |v| quote! {Some(#v)}); - - let idempotent_hint = macro_attributes - .idempotent_hint - .map_or(quote! {None}, |v| quote! {Some(#v)}); - let open_world_hint = macro_attributes - .open_world_hint - .map_or(quote! {None}, |v| quote! {Some(#v)}); - let read_only_hint = macro_attributes - .read_only_hint - .map_or(quote! {None}, |v| quote! {Some(#v)}); - quote! { - Some(#base_crate::ToolAnnotations { - destructive_hint: #destructive_hint, - idempotent_hint: #idempotent_hint, - open_world_hint: #open_world_hint, - read_only_hint: #read_only_hint, - title: None, - }) - } - } else { - quote! { None } - }; - - let annotations_token = { - #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))] - { - quote! { annotations: #annotations, } - } - #[cfg(not(any(feature = "2025_03_26", feature = "2025_06_18")))] - { - quote! {} - } - }; - + let ToolTokens { + base_crate, + tool_name, + tool_description, + meta, + title, + output_schema, + annotations, + execution, + icons, + } = generate_tool_tokens(macro_attributes); + + // TODO: add support for schema version to ToolInputSchema : + // it defaults to JSON Schema 2020-12 when no explicit $schema is provided. let tool_token = quote! { #base_crate::Tool { name: #tool_name.to_string(), @@ -442,8 +85,10 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { #output_schema #title #meta - #annotations_token - input_schema: #base_crate::ToolInputSchema::new(required, properties) + #annotations + #execution + #icons + input_schema: #base_crate::ToolInputSchema::new(required, properties, None) } }; @@ -454,6 +99,27 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { #tool_name.to_string() } + /// Returns a `CallToolRequestParams` initialized with the current tool's name. + /// + /// You can further customize the request by adding arguments or other attributes + /// using the builder pattern. For example: + /// + /// ```ignore + /// # use my_crate::{MyTool}; + /// let args = serde_json::Map::new(); + /// let task_meta = TaskMetadata{ttl: Some(200)} + /// + /// let params: CallToolRequestParams = MyTool::request_params() + /// .with_arguments(args) + /// .with_task(task_meta); + /// ``` + /// + /// # Returns + /// A `CallToolRequestParams` with the tool name set. + pub fn request_params() -> #base_crate::CallToolRequestParams { + #base_crate::CallToolRequestParams::new(#tool_name.to_string()) + } + /// Constructs and returns a `rust_mcp_schema::Tool` instance. /// /// The tool includes the name, description, input schema, meta, and title derived from @@ -503,300 +169,118 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { } #[proc_macro_attribute] -pub fn mcp_elicit(attributes: TokenStream, input: TokenStream) -> TokenStream { +pub fn mcp_elicit(args: TokenStream, input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - let input_ident = &input.ident; - // Conditionally select the path - let base_crate = if cfg!(feature = "sdk") { - quote! { rust_mcp_sdk::schema } - } else { - quote! { rust_mcp_schema } + let fields = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(n) => &n.named, + _ => panic!("mcp_elicit only supports structs with named fields"), + }, + _ => panic!("mcp_elicit only supports structs"), }; - let macro_attributes = parse_macro_input!(attributes as McpElicitationAttributes); - let message = macro_attributes.message.unwrap_or_default(); + let struct_name = &input.ident; + let elicit_args = parse_macro_input!(args as ElicitArgs); - // Generate field assignments for from_content_map() - let field_assignments = match &input.data { - Data::Struct(data) => match &data.fields { - Fields::Named(fields) => { - let assignments = fields.named.iter().map(|field| { - let field_attrs = &field.attrs; - let field_ident = &field.ident; - let renamed_field = renamed_field(field_attrs); - let field_name = renamed_field.unwrap_or_else(|| field_ident.as_ref().unwrap().to_string()); - let field_type = &field.ty; + let base_crate = base_crate(); - let type_check = if is_option(field_type) { - // Extract inner type for Option - let inner_type = match field_type { - Type::Path(type_path) => { - let segment = type_path.path.segments.last().unwrap(); - if segment.ident == "Option" { - match &segment.arguments { - PathArguments::AngleBracketed(args) => { - match args.args.first().unwrap() { - GenericArgument::Type(ty) => ty, - _ => panic!("Expected type argument in Option"), - } - } - _ => panic!("Invalid Option type"), - } - } else { - panic!("Expected Option type"); - } - } - _ => panic!("Expected Option type"), - }; - // Determine the match arm based on the inner type at compile time - let (inner_type_ident, match_pattern, conversion) = match inner_type { - Type::Path(type_path) if type_path.path.is_ident("String") => ( - quote! { String }, - quote! { #base_crate::ElicitResultContentValue::String(s) }, - quote! { s.clone() } - ), - Type::Path(type_path) if type_path.path.is_ident("bool") => ( - quote! { bool }, - quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, - quote! { *b } - ), - Type::Path(type_path) if type_path.path.is_ident("i32") => ( - quote! { i32 }, - quote! { #base_crate::ElicitResultContentValue::Integer(i) }, - quote! { - (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( - "Invalid number for field '{}': value {} does not fit in i32", - #field_name, *i - )))? - } - ), - Type::Path(type_path) if type_path.path.is_ident("i64") => ( - quote! { i64 }, - quote! { #base_crate::ElicitResultContentValue::Integer(i) }, - quote! { *i } - ), - _ if is_enum(inner_type, &input) => { - let enum_parse = generate_enum_parse(inner_type, &field_name, &base_crate); - ( - quote! { #inner_type }, - quote! { #base_crate::ElicitResultContentValue::String(s) }, - quote! { #enum_parse } - ) - } - _ => panic!("Unsupported inner type for Option field: {}", quote! { #inner_type }), - }; - let inner_type_str = quote! { stringify!(#inner_type_ident) }; - quote! { - let #field_ident: Option<#inner_type_ident> = match content.as_ref().and_then(|map| map.get(#field_name)) { - Some(value) => { - match value { - #match_pattern => Some(#conversion), - _ => { - return Err(#base_crate::RpcError::parse_error().with_message(format!( - "Type mismatch for field '{}': expected {}, found {}", - #field_name, #inner_type_str, - match value { - #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", - #base_crate::ElicitResultContentValue::String(_) => "string", - #base_crate::ElicitResultContentValue::Integer(_) => "integer", - } - ))); - } - } - } - None => None, - }; - } - } else { - // Determine the match arm based on the field type at compile time - let (field_type_ident, match_pattern, conversion) = match field_type { - Type::Path(type_path) if type_path.path.is_ident("String") => ( - quote! { String }, - quote! { #base_crate::ElicitResultContentValue::String(s) }, - quote! { s.clone() } - ), - Type::Path(type_path) if type_path.path.is_ident("bool") => ( - quote! { bool }, - quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, - quote! { *b } - ), - Type::Path(type_path) if type_path.path.is_ident("i32") => ( - quote! { i32 }, - quote! { #base_crate::ElicitResultContentValue::Integer(i) }, - quote! { - (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( - "Invalid number for field '{}': value {} does not fit in i32", - #field_name, *i - )))? - } - ), - Type::Path(type_path) if type_path.path.is_ident("i64") => ( - quote! { i64 }, - quote! { #base_crate::ElicitResultContentValue::Integer(i) }, - quote! { *i } - ), - _ if is_enum(field_type, &input) => { - let enum_parse = generate_enum_parse(field_type, &field_name, &base_crate); - ( - quote! { #field_type }, - quote! { #base_crate::ElicitResultContentValue::String(s) }, - quote! { #enum_parse } - ) - } - _ => panic!("Unsupported field type: {}", quote! { #field_type }), - }; - let type_str = quote! { stringify!(#field_type_ident) }; - quote! { - let #field_ident: #field_type_ident = match content.as_ref().and_then(|map| map.get(#field_name)) { - Some(value) => { - match value { - #match_pattern => #conversion, - _ => { - return Err(#base_crate::RpcError::parse_error().with_message(format!( - "Type mismatch for field '{}': expected {}, found {}", - #field_name, #type_str, - match value { - #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", - #base_crate::ElicitResultContentValue::String(_) => "string", - #base_crate::ElicitResultContentValue::Integer(_) => "integer", - } - ))); - } - } - } - None => { - return Err(#base_crate::RpcError::parse_error().with_message(format!( - "Missing required field: {}", - #field_name - ))); - } - }; - } - }; + let message = &elicit_args.message; - type_check - }); + let impl_block = match elicit_args.mode { + ElicitMode::Form => { + let (from_content, init) = generate_from_impl(fields, &base_crate); + let schema = generate_form_schema(struct_name, &base_crate); - let field_idents = fields.named.iter().map(|field| &field.ident); + quote! { + impl #struct_name { + pub fn message() -> &'static str{ + #message + } - quote! { - #(#assignments)* + pub fn requested_schema() -> #base_crate::ElicitFormSchema { + #schema + } - Ok(Self { - #(#field_idents,)* - }) - } - } - _ => panic!("mcp_elicit macro only supports structs with named fields"), - }, - _ => panic!("mcp_elicit macro only supports structs"), - }; + pub fn elicit_mode()->&'static str{ + "form" + } - let output = quote! { - impl #input_ident { + pub fn elicit_form_params() -> #base_crate::ElicitRequestFormParams { + #base_crate::ElicitRequestFormParams::new( + Self::message().to_string(), + Self::requested_schema(), + None, + None, + ) + } - /// Returns the elicitation message defined in the `#[mcp_elicit(message = "...")]` attribute. - /// - /// This message is used to prompt the user or system for input when eliciting data for the struct. - /// If no message is provided in the attribute, an empty string is returned. - /// - /// # Returns - /// A `String` containing the elicitation message. - pub fn message()->String{ - #message.to_string() - } + pub fn elicit_request_params() -> #base_crate::ElicitRequestParams { + Self::elicit_form_params().into() + } - /// This method returns a `ElicitRequestedSchema` by retrieves the - /// struct's JSON schema (via the `JsonSchema` derive) and converting int into - /// a `ElicitRequestedSchema`. It extracts the `required` fields and - /// `properties` from the schema, mapping them to a `HashMap` of `PrimitiveSchemaDefinition` objects. - /// - /// # Returns - /// An `ElicitRequestedSchema` representing the schema of the struct. - /// - /// # Panics - /// Panics if the schema's properties cannot be converted to `PrimitiveSchemaDefinition` or if the schema - /// is malformed. - pub fn requested_schema() -> #base_crate::ElicitRequestedSchema { - let json_schema = &#input_ident::json_schema(); + pub fn from_elicit_result_content( + mut content: Option>, + ) -> Result { + use #base_crate::{ElicitResultContent as V, RpcError}; + let mut map = content.take().unwrap_or_default(); + #from_content + Ok(#init) + } - let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { - Some(arr) => arr - .iter() - .filter_map(|item| item.as_str().map(String::from)) - .collect(), - None => Vec::new(), - }; + } + } + } + ElicitMode::Url { url } => { + let (from_content, init) = generate_from_impl(fields, &base_crate); - let properties: Option> = json_schema - .get("properties") - .and_then(|v| v.as_object()) // Safely extract "properties" as an object. - .map(|properties| { - properties - .iter() - .filter_map(|(key, value)| { - serde_json::to_value(value) - .ok() // If serialization fails, return None. - .and_then(|v| { - if let serde_json::Value::Object(obj) = v { - Some(obj) - } else { - None - } - }) - .map(|obj| (key.to_string(), #base_crate::PrimitiveSchemaDefinition::try_from(&obj))) - }) - .collect() - }); + quote! { + impl #struct_name { + pub fn message() -> &'static str { + #message + } - let properties = properties - .map(|map| { - map.into_iter() - .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple - .collect::, _>>() // collect only if all Ok - }) - .transpose() - .unwrap(); + pub fn url() -> &'static str { + #url + } - let properties = - properties.expect("Was not able to create a ElicitRequestedSchema"); + pub fn elicit_mode()->&'static str { + "url" + } - let requested_schema = #base_crate::ElicitRequestedSchema::new(properties, required); - requested_schema - } + pub fn elicit_url_params(elicitation_id:String) -> #base_crate::ElicitRequestUrlParams { + #base_crate::ElicitRequestUrlParams::new( + elicitation_id, + Self::message().to_string(), + Self::url().to_string(), + None, + None, + ) + } - /// Converts a map of field names and `ElicitResultContentValue` into an instance of the struct. - /// - /// This method parses the provided content map, matching field names to struct fields and converting - /// `ElicitResultContentValue` variants into the appropriate Rust types (e.g., `String`, `bool`, `i32`, - /// `i64`, or simple enums). It supports both required and optional fields (`Option`). - /// - /// # Parameters - /// - `content`: An optional `HashMap` mapping field names to `ElicitResultContentValue` values. - /// - /// # Returns - /// - `Ok(Self)` if the map is successfully parsed into the struct. - /// - `Err(RpcError)` if: - /// - A required field is missing. - /// - A value’s type does not match the expected field type. - /// - An integer value cannot be converted (e.g., `i64` to `i32` out of bounds). - /// - An enum value is invalid (e.g., string value does not match a enum variant name). - /// - /// # Errors - /// Returns `RpcError` with messages like: - /// - `"Missing required field: {}"` - /// - `"Type mismatch for field '{}': expected {}, found {}"` - /// - `"Invalid number for field '{}': value {} does not fit in i32"` - /// - `"Invalid enum value for field '{}': expected 'Yes' or 'No', found '{}'"`. - pub fn from_content_map(content: ::std::option::Option<::std::collections::HashMap<::std::string::String, #base_crate::ElicitResultContentValue>>) -> Result { - #field_assignments + pub fn elicit_request_params(elicitation_id:String) -> #base_crate::ElicitRequestParams { + Self::elicit_url_params(elicitation_id).into() + } + + pub fn from_elicit_result_content( + mut content: Option>, + ) -> Result { + use #base_crate::{ElicitResultContent as V, RpcError}; + let mut map = content.take().unwrap_or_default(); + #from_content + Ok(#init) + } + } } } + }; + + let expanded = quote! { #input + #impl_block }; - TokenStream::from(output) + TokenStream::from(expanded) } /// Derives a JSON Schema representation for a struct. @@ -1051,84 +535,3 @@ pub fn derive_json_schema(input: TokenStream) -> TokenStream { }; TokenStream::from(expanded) } - -#[cfg(test)] -mod tests { - use super::*; - use syn::parse_str; - #[test] - fn test_valid_macro_attributes() { - let input = r#"name = "test_tool", description = "A test tool.", meta = "{\"version\": \"1.0\"}", title = "Test Tool""#; - let parsed: McpToolMacroAttributes = parse_str(input).unwrap(); - - assert_eq!(parsed.name.unwrap(), "test_tool"); - assert_eq!(parsed.description.unwrap(), "A test tool."); - assert_eq!(parsed.meta.unwrap(), "{\"version\": \"1.0\"}"); - assert_eq!(parsed.title.unwrap(), "Test Tool"); - } - - #[test] - fn test_missing_name() { - let input = r#"description = "Only description""#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - "The 'name' attribute is required and must not be empty." - ); - } - - #[test] - fn test_missing_description() { - let input = r#"name = "OnlyName""#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - "The 'description' attribute is required and must not be empty." - ); - } - - #[test] - fn test_empty_name_field() { - let input = r#"name = "", description = "something""#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - "The 'name' attribute is required and must not be empty." - ); - } - - #[test] - fn test_empty_description_field() { - let input = r#"name = "my-tool", description = """#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - "The 'description' attribute is required and must not be empty." - ); - } - - #[test] - fn test_invalid_meta() { - let input = - r#"name = "test_tool", description = "A test tool.", meta = "not_a_json_object""#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert!(result - .err() - .unwrap() - .to_string() - .contains("Expected a valid JSON object")); - } - - #[test] - fn test_non_object_meta() { - let input = r#"name = "test_tool", description = "A test tool.", meta = "[1, 2, 3]""#; - let result: Result = parse_str(input); - assert!(result.is_err()); - assert_eq!(result.err().unwrap().to_string(), "Expected a JSON object"); - } -} diff --git a/crates/rust-mcp-macros/src/tool.rs b/crates/rust-mcp-macros/src/tool.rs new file mode 100644 index 0000000..1ffdea9 --- /dev/null +++ b/crates/rust-mcp-macros/src/tool.rs @@ -0,0 +1,2 @@ +pub(crate) mod generator; +pub(crate) mod parser; diff --git a/crates/rust-mcp-macros/src/tool/generator.rs b/crates/rust-mcp-macros/src/tool/generator.rs new file mode 100644 index 0000000..a7c1e82 --- /dev/null +++ b/crates/rust-mcp-macros/src/tool/generator.rs @@ -0,0 +1,184 @@ +use crate::tool::parser::ExecutionSupportDsl; +use crate::utils::base_crate; +use crate::IconThemeDsl; +use crate::McpToolMacroAttributes; +use proc_macro2::TokenStream; +use quote::quote; + +pub struct ToolTokens { + pub base_crate: TokenStream, + pub tool_name: String, + pub tool_description: String, + pub meta: TokenStream, + pub title: TokenStream, + pub output_schema: TokenStream, + pub annotations: TokenStream, + pub execution: TokenStream, + pub icons: TokenStream, +} + +pub fn generate_tool_tokens(macro_attributes: McpToolMacroAttributes) -> ToolTokens { + // Conditionally select the path for Tool + let base_crate = base_crate(); + let tool_name = macro_attributes.name.clone().unwrap_or_default(); + let tool_description = macro_attributes.description.clone().unwrap_or_default(); + + let title = macro_attributes.title.as_ref().map_or( + quote! { title: None, }, + |t| quote! { title: Some(#t.to_string()), }, + ); + + let meta = macro_attributes + .meta + .as_ref() + .map_or(quote! { meta: None, }, |m| { + quote! { meta: Some(serde_json::from_str(#m).expect("Failed to parse meta JSON")), } + }); + + //TODO: add support for output_schema + let output_schema = quote! { output_schema: None,}; + + let annotations = generate_annotations(&base_crate, ¯o_attributes); + let execution = generate_executions(&base_crate, ¯o_attributes); + let icons = generate_icons(&base_crate, ¯o_attributes); + + ToolTokens { + base_crate, + tool_name, + tool_description, + meta, + title, + output_schema, + annotations, + execution, + icons, + } +} + +fn generate_icons( + base_crate: &TokenStream, + macro_attributes: &McpToolMacroAttributes, +) -> TokenStream { + let mut icon_exprs = Vec::new(); + + if let Some(icons) = ¯o_attributes.icons { + for icon in icons { + let src = &icon.src; + let mime_type = icon + .mime_type + .as_ref() + .map(|s| quote! { Some(#s.to_string()) }) + .unwrap_or(quote! { None }); + let theme = icon + .theme + .as_ref() + .map(|t| match t { + IconThemeDsl::Light => quote! { Some(#base_crate::IconTheme::Light) }, + IconThemeDsl::Dark => quote! { Some(#base_crate::IconTheme::Dark) }, + }) + .unwrap_or(quote! { None }); + + // Build sizes: Vec + let sizes: Vec<_> = icon + .sizes + .as_ref() + .map(|arr| { + arr.elems + .iter() + .map(|elem| { + if let syn::Expr::Lit(expr_lit) = elem { + if let syn::Lit::Str(lit_str) = &expr_lit.lit { + let val = lit_str.value(); + return quote! { #val.to_string() }; + } + } + panic!("sizes must contain only string literals"); + }) + .collect::>() + }) + .unwrap_or_default(); + + let icon_expr = quote! { + #base_crate::Icon { + src: #src.to_string(), + mime_type: #mime_type, + sizes: vec![ #(#sizes),* ], + theme: #theme, + } + }; + icon_exprs.push(icon_expr); + } + } + + if icon_exprs.is_empty() { + quote! { icons: ::std::vec::Vec::new(), } + } else { + quote! { icons: vec![ #(#icon_exprs),* ], } + } +} + +fn generate_executions( + base_crate: &TokenStream, + macro_attributes: &McpToolMacroAttributes, +) -> TokenStream { + if let Some(exec) = macro_attributes.execution.as_ref() { + let task_support = match exec { + ExecutionSupportDsl::Forbidden => { + quote! { Some(#base_crate::ToolExecutionTaskSupport::Forbidden) } + } + ExecutionSupportDsl::Optional => { + quote! { Some(#base_crate::ToolExecutionTaskSupport::Optional) } + } + ExecutionSupportDsl::Required => { + quote! { Some(#base_crate::ToolExecutionTaskSupport::Required) } + } + }; + + quote! { + execution: Some(#base_crate::ToolExecution { + task_support: #task_support, + }), + } + } else { + quote! { execution: None, } + } +} + +fn generate_annotations( + base_crate: &TokenStream, + macro_attributes: &McpToolMacroAttributes, +) -> TokenStream { + let some_annotations = macro_attributes.destructive_hint.is_some() + || macro_attributes.idempotent_hint.is_some() + || macro_attributes.open_world_hint.is_some() + || macro_attributes.read_only_hint.is_some(); + + let annotations = if some_annotations { + let destructive_hint = macro_attributes + .destructive_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + + let idempotent_hint = macro_attributes + .idempotent_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + let open_world_hint = macro_attributes + .open_world_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + let read_only_hint = macro_attributes + .read_only_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + quote! { + Some(#base_crate::ToolAnnotations { + destructive_hint: #destructive_hint, + idempotent_hint: #idempotent_hint, + open_world_hint: #open_world_hint, + read_only_hint: #read_only_hint, + title: None, + }) + } + } else { + quote! { None } + }; + + quote! { annotations: #annotations, } +} diff --git a/crates/rust-mcp-macros/src/tool/parser.rs b/crates/rust-mcp-macros/src/tool/parser.rs new file mode 100644 index 0000000..fbcfee2 --- /dev/null +++ b/crates/rust-mcp-macros/src/tool/parser.rs @@ -0,0 +1,509 @@ +use quote::ToTokens; +use syn::parenthesized; +use syn::parse::ParseStream; +use syn::spanned::Spanned; +use syn::ExprArray; +use syn::{ + parse::Parse, punctuated::Punctuated, Error, Expr, ExprLit, Ident, Lit, LitStr, Meta, Token, +}; + +struct ExprList { + exprs: Punctuated, +} + +impl Parse for ExprList { + fn parse(input: ParseStream) -> syn::Result { + Ok(ExprList { + exprs: Punctuated::parse_terminated(input)?, + }) + } +} + +/// Represents the attributes for the `mcp_tool` procedural macro. +/// +/// This struct parses and validates the attributes provided to the `mcp_tool` macro. +/// The `name` and `description` attributes are required and must not be empty strings. +/// +/// # Fields +/// * `name` - A string representing the tool's name (required). +/// * `description` - A string describing the tool (required). +/// * `meta` - An optional JSON string for metadata. +/// * `title` - An optional string for the tool's title. +/// * The following fields are available only with the `2025_03_26` feature and later: +/// * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`. +/// * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`. +/// * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`. +/// * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`. +/// +pub(crate) struct McpToolMacroAttributes { + pub name: Option, + pub description: Option, + pub meta: Option, // Store raw JSON string instead of parsed Map + pub title: Option, + pub destructive_hint: Option, + pub idempotent_hint: Option, + pub open_world_hint: Option, + pub read_only_hint: Option, + pub execution: Option, + pub icons: Option>, +} + +pub(crate) enum ExecutionSupportDsl { + Forbidden, + Optional, + Required, +} + +pub(crate) struct IconDsl { + pub(crate) src: LitStr, + pub(crate) mime_type: Option, + pub(crate) sizes: Option, + pub(crate) theme: Option, +} + +pub(crate) enum IconThemeDsl { + Light, + Dark, +} + +pub(crate) struct IconField { + pub(crate) key: Ident, + pub(crate) _eq_token: Token![=], + pub(crate) value: syn::Expr, +} + +impl Parse for IconField { + fn parse(input: ParseStream) -> syn::Result { + Ok(IconField { + key: input.parse()?, + _eq_token: input.parse()?, + value: input.parse()?, + }) + } +} + +impl Parse for IconDsl { + fn parse(input: ParseStream) -> syn::Result { + let content; + parenthesized!(content in input); // parse ( ... ) + + let fields: Punctuated = + content.parse_terminated(IconField::parse, Token![,])?; + + let mut src = None; + let mut mime_type = None; + let mut sizes = None; + let mut theme = None; + + for field in fields { + let key_str = field.key.to_string(); + match key_str.as_str() { + "src" => { + if let syn::Expr::Lit(expr_lit) = field.value { + if let syn::Lit::Str(lit) = expr_lit.lit { + src = Some(lit); + } else { + return Err(syn::Error::new( + expr_lit.span(), + "expected string literal for src", + )); + } + } + } + "mime_type" => { + if let syn::Expr::Lit(expr_lit) = field.value { + if let syn::Lit::Str(lit) = expr_lit.lit { + mime_type = Some(lit); + } else { + return Err(syn::Error::new( + expr_lit.span(), + "expected string literal for mime_type", + )); + } + } + } + "sizes" => { + if let syn::Expr::Array(arr) = field.value { + // Validate that every element is a string literal. + for elem in &arr.elems { + match elem { + syn::Expr::Lit(expr_lit) => { + if let syn::Lit::Str(_) = &expr_lit.lit { + // ok + } else { + return Err(syn::Error::new( + expr_lit.span(), + "sizes array must contain string literals", + )); + } + } + _ => { + return Err(syn::Error::new( + elem.span(), + "sizes array must contain only string literals", + )); + } + } + } + + sizes = Some(arr); + } else { + return Err(syn::Error::new( + field.value.span(), + "expected array expression for sizes", + )); + } + } + "theme" => { + if let syn::Expr::Lit(expr_lit) = field.value { + if let syn::Lit::Str(lit) = expr_lit.lit { + theme = Some(match lit.value().as_str() { + "light" => IconThemeDsl::Light, + "dark" => IconThemeDsl::Dark, + _ => { + return Err(syn::Error::new( + lit.span(), + "theme must be \"light\" or \"dark\"", + )); + } + }); + } + } + } + _ => { + return Err(syn::Error::new( + field.key.span(), + "unexpected field in icon", + )) + } + } + } + + Ok(IconDsl { + src: src.ok_or_else(|| syn::Error::new(input.span(), "icon must have `src`"))?, + mime_type, + sizes, + theme, + }) + } +} + +impl Parse for IconThemeDsl { + fn parse(_input: ParseStream) -> syn::Result { + panic!("IconThemeDsl should be parsed inside IconDsl") + } +} + +impl Parse for McpToolMacroAttributes { + /// Parses the macro attributes from a `ParseStream`. + /// + /// This implementation extracts `name`, `description`, `meta`, and `title` from the attribute input. + /// The `name` and `description` must be provided as string literals and be non-empty. + /// The `meta` attribute must be a valid JSON object provided as a string literal, and `title` must be a string literal. + /// + /// # Errors + /// Returns a `syn::Error` if: + /// - The `name` attribute is missing or empty. + /// - The `description` attribute is missing or empty. + /// - The `meta` attribute is provided but is not a valid JSON object. + /// - The `title` attribute is provided but is not a string literal. + fn parse(attributes: syn::parse::ParseStream) -> syn::Result { + let mut instance = Self { + name: None, + description: None, + meta: None, + title: None, + destructive_hint: None, + idempotent_hint: None, + open_world_hint: None, + read_only_hint: None, + execution: None, + icons: None, + }; + + let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; + for meta in meta_list { + match meta { + Meta::NameValue(meta_name_value) => { + let ident = meta_name_value.path.get_ident().unwrap(); + let ident_str = ident.to_string(); + + match ident_str.as_str() { + "name" | "description" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + Expr::Macro(expr_macro) => { + let mac = &expr_macro.mac; + if mac.path.is_ident("concat") { + let args: ExprList = syn::parse2(mac.tokens.clone())?; + let mut result = String::new(); + for expr in args.exprs { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = expr + { + result.push_str(&lit_str.value()); + } else { + return Err(Error::new_spanned( + expr, + "Only string literals are allowed inside concat!()", + )); + } + } + result + } else { + return Err(Error::new_spanned( + expr_macro, + "Only concat!(...) is supported here", + )); + } + } + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal or concat!(...)", + )); + } + }; + match ident_str.as_str() { + "name" => instance.name = Some(value), + "description" => instance.description = Some(value), + _ => {} + } + } + "meta" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a JSON object as a string literal", + )); + } + }; + // Validate that the string is a valid JSON object + let parsed: serde_json::Value = + serde_json::from_str(&value).map_err(|e| { + Error::new_spanned( + &meta_name_value.value, + format!("Expected a valid JSON object: {e}"), + ) + })?; + if !parsed.is_object() { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a JSON object", + )); + } + instance.meta = Some(value); + } + "title" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal", + )); + } + }; + instance.title = Some(value); + } + "destructive_hint" | "idempotent_hint" | "open_world_hint" + | "read_only_hint" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Bool(lit_bool), + .. + }) => lit_bool.value, + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a boolean literal", + )); + } + }; + + match ident_str.as_str() { + "destructive_hint" => instance.destructive_hint = Some(value), + "idempotent_hint" => instance.idempotent_hint = Some(value), + "open_world_hint" => instance.open_world_hint = Some(value), + "read_only_hint" => instance.read_only_hint = Some(value), + _ => {} + } + } + "icons" => { + // Check if the value is an array (Expr::Array) + if let Expr::Array(array_expr) = &meta_name_value.value { + let icon_list: Punctuated = array_expr + .elems + .iter() + .map(|elem| syn::parse2::(elem.to_token_stream())) + .collect::>()?; + instance.icons = Some(icon_list.into_iter().collect()); + } else { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected an array for the 'icons' attribute", + )); + } + } + other => { + eprintln!("other: {:?}", other) + } + } + } + Meta::List(meta_list) => { + let ident = meta_list.path.get_ident().unwrap(); + let ident_str = ident.to_string(); + + if ident_str == "execution" { + let nested = meta_list + .parse_args_with(Punctuated::::parse_terminated)?; + let mut task_support = None; + + for meta in nested { + if let Meta::NameValue(nv) = meta { + if nv.path.is_ident("task_support") { + if let Expr::Lit(ExprLit { + lit: Lit::Str(s), .. + }) = &nv.value + { + let value = s.value(); + task_support = Some(match value.as_str() { + "forbidden" => ExecutionSupportDsl::Forbidden, + "optional" => ExecutionSupportDsl::Optional, + "required" => ExecutionSupportDsl::Required, + _ => return Err(Error::new_spanned(&nv.value, "task_support must be one of: forbidden, optional, required")), + }); + } + } + } + } + + instance.execution = task_support; + } + } + _ => {} + } + } + + // Validate presence and non-emptiness + if instance + .name + .as_ref() + .map(|s| s.trim().is_empty()) + .unwrap_or(true) + { + return Err(Error::new( + attributes.span(), + "The 'name' attribute is required and must not be empty.", + )); + } + + if instance + .description + .as_ref() + .map(|s| s.trim().is_empty()) + .unwrap_or(true) + { + return Err(Error::new( + attributes.span(), + "The 'description' attribute is required and must not be empty.", + )); + } + + Ok(instance) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use syn::parse_str; + #[test] + fn test_valid_macro_attributes() { + let input = r#"name = "test_tool", description = "A test tool.", meta = "{\"version\": \"1.0\"}", title = "Test Tool""#; + let parsed: McpToolMacroAttributes = parse_str(input).unwrap(); + + assert_eq!(parsed.name.unwrap(), "test_tool"); + assert_eq!(parsed.description.unwrap(), "A test tool."); + assert_eq!(parsed.meta.unwrap(), "{\"version\": \"1.0\"}"); + assert_eq!(parsed.title.unwrap(), "Test Tool"); + } + + #[test] + fn test_missing_name() { + let input = r#"description = "Only description""#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "The 'name' attribute is required and must not be empty." + ); + } + + #[test] + fn test_missing_description() { + let input = r#"name = "OnlyName""#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "The 'description' attribute is required and must not be empty." + ); + } + + #[test] + fn test_empty_name_field() { + let input = r#"name = "", description = "something""#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "The 'name' attribute is required and must not be empty." + ); + } + + #[test] + fn test_empty_description_field() { + let input = r#"name = "my-tool", description = """#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "The 'description' attribute is required and must not be empty." + ); + } + + #[test] + fn test_invalid_meta() { + let input = + r#"name = "test_tool", description = "A test tool.", meta = "not_a_json_object""#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Expected a valid JSON object")); + } + + #[test] + fn test_non_object_meta() { + let input = r#"name = "test_tool", description = "A test tool.", meta = "[1, 2, 3]""#; + let result: Result = parse_str(input); + assert!(result.is_err()); + assert_eq!(result.err().unwrap().to_string(), "Expected a JSON object"); + } +} diff --git a/crates/rust-mcp-macros/src/utils.rs b/crates/rust-mcp-macros/src/utils.rs index 71d3de3..ca1c4c2 100644 --- a/crates/rust-mcp-macros/src/utils.rs +++ b/crates/rust-mcp-macros/src/utils.rs @@ -1,9 +1,19 @@ +use proc_macro2::TokenStream; use quote::quote; use syn::{ - punctuated::Punctuated, token, Attribute, DeriveInput, Lit, LitInt, LitStr, Path, - PathArguments, Type, + punctuated::Punctuated, token, Attribute, DeriveInput, GenericArgument, Lit, LitInt, LitStr, + Path, PathArguments, Type, TypePath, }; +pub fn base_crate() -> TokenStream { + // Conditionally select the path for Tool + if cfg!(feature = "sdk") { + quote! { rust_mcp_sdk::schema } + } else { + quote! { rust_mcp_schema } + } +} + // Check if a type is an Option pub fn is_option(ty: &Type) -> bool { if let Type::Path(type_path) = ty { @@ -84,6 +94,7 @@ pub fn might_be_struct(ty: &Type) -> bool { false } +#[allow(unused)] // Helper to check if a type is an enum pub fn is_enum(ty: &Type, _input: &DeriveInput) -> bool { if let Type::Path(type_path) = ty { @@ -108,6 +119,7 @@ pub fn is_enum(ty: &Type, _input: &DeriveInput) -> bool { } } +#[allow(unused)] // Helper to generate enum parsing code pub fn generate_enum_parse( field_type: &Type, @@ -457,6 +469,38 @@ pub fn has_derive(attrs: &[Attribute], trait_name: &str) -> bool { }) } +pub fn is_vec_string(ty: &Type) -> bool { + let Type::Path(TypePath { path, .. }) = ty else { + return false; + }; + + // Get last segment: e.g., `Vec` + let Some(seg) = path.segments.last() else { + return false; + }; + + // Must be `Vec` + if seg.ident != "Vec" { + return false; + } + + // Must have angle-bracketed args: + let PathArguments::AngleBracketed(args) = &seg.arguments else { + return false; + }; + + // Must contain exactly one type param + if args.args.len() != 1 { + return false; + } + + // Check that the argument is `String` + match args.args.first().unwrap() { + GenericArgument::Type(Type::Path(tp)) => tp.path.is_ident("String"), + _ => false, + } +} + pub fn renamed_field(attrs: &[Attribute]) -> Option { let mut renamed = None; diff --git a/crates/rust-mcp-macros/tests/common/common.rs b/crates/rust-mcp-macros/tests/common/common.rs index 1133d64..e754cdb 100644 --- a/crates/rust-mcp-macros/tests/common/common.rs +++ b/crates/rust-mcp-macros/tests/common/common.rs @@ -48,30 +48,3 @@ impl FromStr for Colors { } } } - -#[mcp_elicit(message = "Please enter your info")] -#[derive(JsonSchema)] -pub struct UserInfo { - #[json_schema( - title = "Name", - description = "The user's full name", - min_length = 5, - max_length = 100 - )] - pub name: String, - - /// Email address of the user - #[json_schema(title = "Email", format = "email")] - pub email: Option, - - /// The user's age in years - #[json_schema(title = "Age", minimum = 15, maximum = 125)] - pub age: i32, - - /// Is user a student? - #[json_schema(title = "Is student?", default = true)] - pub is_student: Option, - - /// User's favorite color - pub favorate_color: Colors, -} diff --git a/crates/rust-mcp-macros/tests/macro_test.rs b/crates/rust-mcp-macros/tests/macro_test.rs deleted file mode 100644 index 4b6c926..0000000 --- a/crates/rust-mcp-macros/tests/macro_test.rs +++ /dev/null @@ -1,274 +0,0 @@ -#[macro_use] -extern crate rust_mcp_macros; - -use std::collections::HashMap; - -use common::EditOperation; -use rust_mcp_schema::{ - BooleanSchema, ElicitRequestedSchema, ElicitResultContentValue, EnumSchema, NumberSchema, - PrimitiveSchemaDefinition, StringSchema, StringSchemaFormat, -}; -use serde_json::json; - -use crate::common::{Colors, UserInfo}; - -#[path = "common/common.rs"] -pub mod common; - -#[test] -fn test_rename() { - let schema = EditOperation::json_schema(); - - assert_eq!(schema.len(), 3); - - assert!(schema.contains_key("properties")); - assert!(schema.contains_key("required")); - - assert!(schema.contains_key("type")); - assert_eq!(schema.get("type").unwrap(), "object"); - - let required: Vec<_> = schema - .get("required") - .unwrap() - .as_array() - .unwrap() - .iter() - .filter_map(|v| v.as_str()) - .collect(); - - assert_eq!(required.len(), 2); - assert!(required.contains(&"oldText")); - assert!(required.contains(&"newText")); - - let properties = schema.get("properties").unwrap().as_object().unwrap(); - assert_eq!(properties.len(), 2); -} - -#[test] -fn test_attributes() { - #[derive(JsonSchema)] - struct User { - /// This is a fallback description from doc comment. - pub id: i32, - - #[json_schema( - title = "User Name", - description = "The user's full name (overrides doc)", - min_length = 1, - max_length = 100 - )] - pub name: String, - - #[json_schema( - title = "User Email", - format = "email", - min_length = 5, - max_length = 255 - )] - pub email: Option, - - #[json_schema( - title = "Tags", - description = "List of tags", - min_length = 0, - max_length = 10 - )] - pub tags: Vec, - } - - let schema = User::json_schema(); - let expected = json!({ - "type": "object", - "properties": { - "id": { - "type": "integer", - "description": "This is a fallback description from doc comment." - }, - "name": { - "type": "string", - "title": "User Name", - "description": "The user's full name (overrides doc)", - "minLength": 1, - "maxLength": 100 - }, - "email": { - "type": "string", - "title": "User Email", - "format": "email", - "minLength": 5, - "maxLength": 255, - "nullable": true - }, - "tags": { - "type": "array", - "items": { - "type": "string", - }, - "title": "Tags", - "description": "List of tags", - "minItems": 0, - "maxItems": 10 - } - }, - "required": ["id", "name", "tags"] - }); - - // Convert expected_value from serde_json::Value to serde_json::Map - let expected: serde_json::Map = - expected.as_object().expect("Expected JSON object").clone(); - - assert_eq!(schema, expected); -} - -#[test] -fn test_elicit_macro() { - assert_eq!(UserInfo::message(), "Please enter your info"); - - let requested_schema: ElicitRequestedSchema = UserInfo::requested_schema(); - assert_eq!( - requested_schema.required, - vec!["name", "age", "favorate_color"] - ); - - assert!(matches!( - requested_schema.properties.get("is_student").unwrap(), - PrimitiveSchemaDefinition::BooleanSchema(BooleanSchema { - default, - description, - title, - .. - }) - if - description.as_ref().unwrap() == "Is user a student?" && - title.as_ref().unwrap() == "Is student?" && - matches!(default, Some(true)) - - )); - - assert!(matches!( - requested_schema.properties.get("favorate_color").unwrap(), - PrimitiveSchemaDefinition::EnumSchema(EnumSchema { - description, - enum_, - enum_names, - title, - .. - }) - if description.as_ref().unwrap() == "User's favorite color" && - title.is_none() && - enum_.len()==2 && enum_.iter().all(|s| ["Green", "Red"].contains(&s.as_str())) && - enum_names.len()==2 && enum_names.iter().all(|s| ["Green Color", "Red Color"].contains(&s.as_str())) - )); - - assert!(matches!( - requested_schema.properties.get("age").unwrap(), - PrimitiveSchemaDefinition::NumberSchema(NumberSchema { - description, - maximum, - minimum, - title, - type_ - }) - if - description.as_ref().unwrap() == "The user's age in years" && - maximum.unwrap() == 125 && minimum.unwrap() == 15 && title.as_ref().unwrap() == "Age" - )); - - assert!(matches!( - requested_schema.properties.get("name").unwrap(), - PrimitiveSchemaDefinition::StringSchema(StringSchema { - description, - format, - max_length, - min_length, - title, - .. - }) - if format.is_none() && - description.as_ref().unwrap() == "The user's full name" && - max_length.unwrap() == 100 && min_length.unwrap() == 5 && title.as_ref().unwrap() == "Name" - )); - - assert!(matches!( - requested_schema.properties.get("email").unwrap(), - PrimitiveSchemaDefinition::StringSchema(StringSchema { - description, - format, - max_length, - min_length, - title, - .. - }) if matches!(format.unwrap(), StringSchemaFormat::Email) && - description.as_ref().unwrap() == "Email address of the user" && - max_length.is_none() && min_length.is_none() && title.as_ref().unwrap() == "Email" - )); - - let json_schema = &UserInfo::json_schema(); - - let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { - Some(arr) => arr - .iter() - .filter_map(|item| item.as_str().map(String::from)) - .collect(), - None => Vec::new(), - }; - - let properties: Option> = json_schema - .get("properties") - .and_then(|v| v.as_object()) // Safely extract "properties" as an object. - .map(|properties| { - properties - .iter() - .filter_map(|(key, value)| { - serde_json::to_value(value) - .ok() // If serialization fails, return None. - .and_then(|v| { - if let serde_json::Value::Object(obj) = v { - Some(obj) - } else { - None - } - }) - .map(|obj| (key.to_string(), PrimitiveSchemaDefinition::try_from(&obj))) - }) - .collect() - }); - - let properties = properties - .map(|map| { - map.into_iter() - .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple - .collect::, _>>() // collect only if all Ok - }) - .transpose() - .unwrap(); - - let properties = properties.expect("Was not able to create a ElicitRequestedSchema"); - - ElicitRequestedSchema::new(properties, required); -} - -#[test] -fn test_from_content_map() { - let mut content: ::std::collections::HashMap<::std::string::String, ElicitResultContentValue> = - HashMap::new(); - - content.extend([ - ( - "name".to_string(), - ElicitResultContentValue::String("Ali".to_string()), - ), - ( - "favorate_color".to_string(), - ElicitResultContentValue::String("Green".to_string()), - ), - ("age".to_string(), ElicitResultContentValue::Integer(15)), - ( - "is_student".to_string(), - ElicitResultContentValue::Boolean(false), - ), - ]); - - let u: UserInfo = UserInfo::from_content_map(Some(content)).unwrap(); - assert!(matches!(u.favorate_color, Colors::Green)); -} diff --git a/crates/rust-mcp-macros/tests/test_mcp_elicit.rs b/crates/rust-mcp-macros/tests/test_mcp_elicit.rs new file mode 100644 index 0000000..dba92d9 --- /dev/null +++ b/crates/rust-mcp-macros/tests/test_mcp_elicit.rs @@ -0,0 +1,403 @@ +use rust_mcp_macros::{mcp_elicit, JsonSchema}; +use rust_mcp_schema::{ + ElicitRequestFormParams, ElicitRequestParams, ElicitRequestUrlParams, ElicitResultContent, + RpcError, +}; +use std::collections::HashMap; + +#[test] +fn test_form_basic_conversion() { + // Form elicit basic + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Please enter your name and age", mode=form)] + pub struct BasicUser { + pub name: String, + pub age: Option, + pub expertise: Vec, + } + assert_eq!(BasicUser::message(), "Please enter your name and age"); + let mut content: std::collections::HashMap = HashMap::new(); + content.insert( + "name".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "Ali".to_string(), + )), + ); + content.insert( + "age".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Integer(21)), + ); + content.insert( + "expertise".to_string(), + ElicitResultContent::StringArray(vec!["Rust".to_string(), "C++".to_string()]), + ); + + let user: BasicUser = BasicUser::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Ali"); + assert_eq!(user.age, Some(21)); + assert_eq!(user.expertise, vec!["Rust".to_string(), "C++".to_string()]); + + let req = BasicUser::elicit_request_params(); + match req { + ElicitRequestParams::FormParams(form) => { + assert_eq!(form.message, "Please enter your name and age"); + assert!(form.requested_schema.properties.contains_key("name")); + assert!(form.requested_schema.properties.contains_key("age")); + assert_eq!(form.requested_schema.required, vec!["name", "expertise"]); // age is optional + assert!(form.meta.is_none()); + assert_eq!(form.mode().as_ref().unwrap(), "form"); + } + _ => panic!("Expected FormParams"), + } +} + +#[test] +fn test_url_basic_conversion() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Please enter your name and age", mode=url, url="https://github.com/rust-mcp-stack/rust-mcp-sdk")] + pub struct InfoFromUrl { + pub name: String, + pub age: Option, + pub expertise: Vec, + } + + assert_eq!(InfoFromUrl::message(), "Please enter your name and age"); + assert_eq!( + InfoFromUrl::url(), + "https://github.com/rust-mcp-stack/rust-mcp-sdk" + ); + + let mut content: std::collections::HashMap = HashMap::new(); + content.insert( + "name".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "Ali".to_string(), + )), + ); + content.insert( + "age".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Integer(21)), + ); + content.insert( + "expertise".to_string(), + ElicitResultContent::StringArray(vec!["Rust".to_string(), "C++".to_string()]), + ); + + let user: InfoFromUrl = InfoFromUrl::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Ali"); + assert_eq!(user.age, Some(21)); + assert_eq!(user.expertise, vec!["Rust".to_string(), "C++".to_string()]); + let req = InfoFromUrl::elicit_request_params("elicit_id".to_string()); + match req { + ElicitRequestParams::UrlParams(params) => { + assert_eq!(params.message, "Please enter your name and age"); + assert!(params.meta.is_none()); + assert!(params.task.is_none()); + assert_eq!(params.mode(), "url"); + } + _ => panic!("Expected UrlParams"), + } +} + +#[test] +fn test_missing_required_field_returns_error() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Enter user info", mode = form)] + pub struct RequiredFields { + pub name: String, + pub email: String, + pub tags: Vec, + } + + let mut content = HashMap::new(); + content.insert( + "name".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "Alice".to_string(), + )), + ); + // Missing 'email' and 'tags' - both required + + let result = RequiredFields::from_elicit_result_content(Some(content)); + assert!(result.is_err()); +} + +#[test] +fn test_extra_unknown_field_is_ignored() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Test", mode = form)] + pub struct StrictStruct { + pub name: String, + } + + let mut content = HashMap::new(); + content.insert( + "name".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "Bob".to_string(), + )), + ); + content.insert( + "unknown_field".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "ignored".to_string(), + )), + ); + + let user = StrictStruct::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Bob"); + // unknown_field is silently ignored - correct behavior +} + +#[test] +fn test_type_mismatch_returns_error() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Bad type", mode = form)] + pub struct TypeSensitive { + pub age: i32, + pub active: bool, + } + + let mut content = HashMap::new(); + content.insert( + "age".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "not_a_number".to_string(), + )), + ); + content.insert( + "active".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Integer(1)), + ); + + let result = TypeSensitive::from_elicit_result_content(Some(content)); + assert!(result.is_err()); +} + +#[test] +fn test_empty_string_array_when_missing_optional_vec() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Optional vec", mode = form)] + pub struct OptionalVec { + pub name: String, + pub hobbies: Option>, + } + + let mut content = HashMap::new(); + content.insert( + "name".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "Charlie".to_string(), + )), + ); + // hobbies omitted entirely + + let user = OptionalVec::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Charlie"); + assert_eq!(user.hobbies, None); +} + +#[test] +fn test_empty_content_map_becomes_default_values() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Defaults", mode = form)] + pub struct WithOptionals { + pub name: String, + pub age: i64, + pub is_admin: bool, + } + + let result = WithOptionals::from_elicit_result_content(None); + assert!(result.is_err()); + + let result_empty = WithOptionals::from_elicit_result_content(Some(HashMap::new())); + assert!(result_empty.is_err()); +} + +#[test] +fn test_boolean_handling() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Bool test", mode = form)] + pub struct BoolStruct { + pub is_active: bool, + pub has_permission: Option, + } + + let mut content = HashMap::new(); + content.insert( + "is_active".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Boolean( + true, + )), + ); + content.insert( + "has_permission".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Boolean( + false, + )), + ); + + let s = BoolStruct::from_elicit_result_content(Some(content)).unwrap(); + assert!(s.is_active); + assert_eq!(s.has_permission, Some(false)); +} + +#[test] +fn test_numeric_types_variations() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Numbers", mode = form)] + pub struct Numbers { + pub count: i32, + pub ratio: Option, + } + + let mut content = HashMap::new(); + content.insert( + "count".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::Integer(42)), + ); + + let n = Numbers::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(n.count, 42); + assert_eq!(n.ratio, None); +} + +#[test] +fn test_url_mode_with_elicitation_id() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Go to this link", mode = url, url = "https://example.com/form/123")] + pub struct ExternalForm { + pub token: String, + } + + let params = ExternalForm::elicit_url_params("elicit-999".to_string()); + assert_eq!(params.elicitation_id, "elicit-999"); + assert_eq!(params.message, "Go to this link"); + assert_eq!(params.url, "https://example.com/form/123"); + + let req_params = ExternalForm::elicit_request_params("elicit-999".to_string()); + match req_params { + ElicitRequestParams::UrlParams(p) => { + assert_eq!(p.elicitation_id, "elicit-999"); + } + _ => panic!("Wrong variant"), + } +} +#[test] +fn test_form_and_url_share_same_from_elicit_result_content_logic() { + // This ensures both modes reuse the same parsing logic (good!) + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Same parsing", mode = form)] + pub struct FormSame { + pub x: String, + } + + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Same parsing", mode = url, url = "http://localhost")] + pub struct UrlSame { + pub x: String, + } + + let mut content = HashMap::new(); + content.insert( + "x".to_string(), + ElicitResultContent::Primitive(rust_mcp_schema::ElicitResultContentPrimitive::String( + "shared".to_string(), + )), + ); + + let f = FormSame::from_elicit_result_content(Some(content.clone())).unwrap(); + let u = UrlSame::from_elicit_result_content(Some(content)).unwrap(); + + assert_eq!(f.x, "shared"); + assert_eq!(u.x, "shared"); +} + +#[test] +fn test_string_array_empty_input_becomes_empty_vec() { + #[derive(Debug, Clone, JsonSchema)] + #[mcp_elicit(message = "Empty array", mode = form)] + pub struct EmptyArray { + pub items: Vec, + } + + let mut content = HashMap::new(); + content.insert( + "items".to_string(), + ElicitResultContent::StringArray(vec![]), + ); + + let s = EmptyArray::from_elicit_result_content(Some(content)).unwrap(); + assert!(s.items.is_empty()); +} + +#[test] +fn readme_example_elicitation() { + use rust_mcp_macros::{mcp_elicit, JsonSchema}; + use rust_mcp_schema::{ElicitRequestParams, ElicitResultContent}; + use std::collections::HashMap; + + #[mcp_elicit(message = "Please enter your info", mode = form)] + #[derive(JsonSchema)] + pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, + } + + let params = UserInfo::elicit_request_params(); + if let ElicitRequestParams::FormParams(form) = params { + assert_eq!(form.message, "Please enter your info"); + } + + // Simulate user input + let mut content: HashMap = HashMap::new(); + content.insert("name".to_string(), "Alice".into()); + content.insert("email".to_string(), "alice@Borderland.com".into()); + content.insert("age".to_string(), 25.into()); + content.insert("tags".to_string(), vec!["rust", "c++"].into()); + + let user = UserInfo::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Alice"); + assert_eq!(user.age, 25); + assert_eq!(user.tags, vec!["rust", "c++"]); + assert_eq!(user.email.unwrap(), "alice@Borderland.com"); +} + +#[test] +fn readme_example_elicitation_url() { + #[mcp_elicit(message = "Complete the form", mode = url, url = "https://example.com/form")] + #[derive(JsonSchema)] + pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, + } + + let elicit_url = UserInfo::elicit_url_params("elicit_10".into()); + + assert_eq!(elicit_url.message, "Complete the form"); + + // Simulate user input + let mut content: HashMap = HashMap::new(); + content.insert("name".to_string(), "Alice".into()); + content.insert("email".to_string(), "alice@Borderland.com".into()); + content.insert("age".to_string(), 25.into()); + content.insert("tags".to_string(), vec!["rust", "c++"].into()); + + let user = UserInfo::from_elicit_result_content(Some(content)).unwrap(); + assert_eq!(user.name, "Alice"); + assert_eq!(user.age, 25); + assert_eq!(user.tags, vec!["rust", "c++"]); + assert_eq!(user.email.unwrap(), "alice@Borderland.com"); +} diff --git a/crates/rust-mcp-macros/tests/test_mcp_tool.rs b/crates/rust-mcp-macros/tests/test_mcp_tool.rs new file mode 100644 index 0000000..1f4a213 --- /dev/null +++ b/crates/rust-mcp-macros/tests/test_mcp_tool.rs @@ -0,0 +1,477 @@ +#[macro_use] +extern crate rust_mcp_macros; +use common::EditOperation; +use rust_mcp_macros::{mcp_elicit, JsonSchema}; +use rust_mcp_schema::{ + CallToolRequestParams, ElicitRequestFormParams, ElicitRequestParams, ElicitResultContent, + ElicitResultContentPrimitive, RpcError, +}; +use rust_mcp_schema::{IconTheme, Tool, ToolExecutionTaskSupport}; +use serde_json::json; + +#[path = "common/common.rs"] +pub mod common; + +#[test] +fn test_rename() { + let schema = EditOperation::json_schema(); + + assert_eq!(schema.len(), 3); + + assert!(schema.contains_key("properties")); + assert!(schema.contains_key("required")); + + assert!(schema.contains_key("type")); + assert_eq!(schema.get("type").unwrap(), "object"); + + let required: Vec<_> = schema + .get("required") + .unwrap() + .as_array() + .unwrap() + .iter() + .filter_map(|v| v.as_str()) + .collect(); + + assert_eq!(required.len(), 2); + assert!(required.contains(&"oldText")); + assert!(required.contains(&"newText")); + + let properties = schema.get("properties").unwrap().as_object().unwrap(); + assert_eq!(properties.len(), 2); +} + +#[test] +fn test_attributes() { + #[derive(JsonSchema)] + struct User { + /// This is a fallback description from doc comment. + pub id: i32, + + #[json_schema( + title = "User Name", + description = "The user's full name (overrides doc)", + min_length = 1, + max_length = 100 + )] + pub name: String, + + #[json_schema( + title = "User Email", + format = "email", + min_length = 5, + max_length = 255 + )] + pub email: Option, + + #[json_schema( + title = "Tags", + description = "List of tags", + min_length = 0, + max_length = 10 + )] + pub tags: Vec, + } + + let schema = User::json_schema(); + let expected = json!({ + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "This is a fallback description from doc comment." + }, + "name": { + "type": "string", + "title": "User Name", + "description": "The user's full name (overrides doc)", + "minLength": 1, + "maxLength": 100 + }, + "email": { + "type": "string", + "title": "User Email", + "format": "email", + "minLength": 5, + "maxLength": 255, + "nullable": true + }, + "tags": { + "type": "array", + "items": { + "type": "string", + }, + "title": "Tags", + "description": "List of tags", + "minItems": 0, + "maxItems": 10 + } + }, + "required": ["id", "name", "tags"] + }); + + // Convert expected_value from serde_json::Value to serde_json::Map + let expected: serde_json::Map = + expected.as_object().expect("Expected JSON object").clone(); + + assert_eq!(schema, expected); +} + +#[test] +fn basic_tool_name_and_description() { + #[derive(JsonSchema)] + #[mcp_tool(name = "echo", description = "Repeats input")] + struct Echo { + message: String, + } + + let tool = Echo::tool(); + assert_eq!(tool.name, "echo"); + assert_eq!(tool.description.unwrap(), "Repeats input"); +} + +#[test] +fn meta_json_is_parsed_correctly() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "weather", + description = "Get weather", + meta = r#"{"category": "utility", "version": "1.0"}"# + )] + struct Weather { + location: String, + } + + let tool = Weather::tool(); + let meta = tool.meta.as_ref().unwrap(); + assert_eq!(meta["category"], "utility"); + assert_eq!(meta["version"], "1.0"); +} + +#[test] +fn title_is_set() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "calculator", + description = "Math tool", + title = "Scientific Calculator" + )] + struct Calc { + expression: String, + } + + let tool = Calc::tool(); + assert_eq!(tool.title.unwrap(), "Scientific Calculator"); +} + +#[test] +fn all_annotations_are_set() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "delete_file", + description = "Deletes a file", + destructive_hint = true, + idempotent_hint = false, + open_world_hint = true, + read_only_hint = false + )] + struct DeleteFile { + path: String, + } + + let tool = DeleteFile::tool(); + let ann = tool.annotations.as_ref().unwrap(); + + assert!(ann.destructive_hint.unwrap()); + assert!(!ann.idempotent_hint.unwrap()); + assert!(ann.open_world_hint.unwrap()); + assert!(!ann.read_only_hint.unwrap()); +} + +#[test] +fn partial_annotations_some_set_some_not() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "get_user", + description = "Fetch user", + read_only_hint = true, + idempotent_hint = true + )] + struct GetUser { + id: String, + } + + let tool = GetUser::tool(); + let ann = tool.annotations.as_ref().unwrap(); + + assert!(ann.read_only_hint.unwrap()); + assert!(ann.idempotent_hint.unwrap()); + assert!(ann.destructive_hint.is_none()); + assert!(ann.open_world_hint.is_none()); +} + +#[test] +fn execution_task_support_required() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "long_task", + description = "desc", + execution(task_support = "required") + )] + struct LongTask { + data: String, + } + + let tool = LongTask::tool(); + let exec = tool.execution.as_ref().unwrap(); + assert_eq!(exec.task_support, Some(ToolExecutionTaskSupport::Required)); +} + +#[test] +fn execution_task_support_optional_and_forbidden() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "quick_op", + description = "description", + execution(task_support = "optional") + )] + struct QuickOp { + value: i32, + } + + #[derive(JsonSchema)] + #[mcp_tool( + name = "no_task", + description = "description", + execution(task_support = "forbidden") + )] + struct NoTask { + flag: bool, + } + + assert_eq!( + QuickOp::tool().execution.unwrap().task_support, + Some(ToolExecutionTaskSupport::Optional) + ); + assert_eq!( + NoTask::tool().execution.unwrap().task_support, + Some(ToolExecutionTaskSupport::Forbidden) + ); +} + +// #[derive(JsonSchema)] +// #[mcp_tool( +// name = "icon_tool", +// icons = [ +// { src = "/icons/light.png", mime_type = "image/png", sizes = ["48x48", "96x96"], theme = "light" }, +// { src = "/icons/dark.svg", mime_type = "image/svg+xml", sizes = ["any"], theme = "dark" }, +// { src = "/icons/default.ico", sizes = ["32x32"] } // no mime/theme +// ] +// )] +// struct IconTool { +// input: String, +// } + +#[test] +fn icons_full_support() { + #[derive(JsonSchema)] + #[mcp_tool( + name = "icon_tool", + description="desc", + icons = [ + (src = "/icons/light.png", mime_type = "image/png", sizes = ["48x48", "96x96"], theme = "light" ), + ( src = "/icons/dark.svg", mime_type = "image/svg+xml", sizes = ["any"], theme = "dark" ), + ( src = "/icons/default.ico", sizes = ["32x32"] ) + ] + )] + struct IconTool { + input: String, + } + + let tool = IconTool::tool(); + let icons = &tool.icons; + + assert_eq!(icons.len(), 3); + + assert_eq!(icons[0].src, "/icons/light.png"); + assert_eq!(icons[0].mime_type.as_deref(), Some("image/png")); + assert_eq!(icons[0].sizes, vec!["48x48", "96x96"]); + assert_eq!(icons[0].theme, Some(IconTheme::Light)); + + assert_eq!(icons[1].src, "/icons/dark.svg"); + assert_eq!(icons[1].mime_type.as_deref(), Some("image/svg+xml")); + assert_eq!(icons[1].sizes, vec!["any"]); + assert_eq!(icons[1].theme, Some(IconTheme::Dark)); + + assert_eq!(icons[2].src, "/icons/default.ico"); + assert_eq!(icons[2].mime_type, None); + assert_eq!(icons[2].sizes, vec!["32x32"]); + assert_eq!(icons[2].theme, None); +} + +#[test] +fn icons_empty_when_not_provided() { + #[derive(JsonSchema)] + #[mcp_tool(name = "no_icons", description = "no_icons")] + struct NoIcons { + _x: i32, + } + assert!(NoIcons::tool().icons.is_empty()); +} + +#[test] +fn input_schema_has_correct_required_fields() { + #[derive(JsonSchema)] + #[mcp_tool(name = "user_create", description = "user_create")] + struct UserCreate { + username: String, + email: String, + age: Option, + tags: Vec, + } + + let tool: Tool = UserCreate::tool(); + let required = tool.input_schema.required; + assert!(required.contains(&"username".to_string())); + assert!(required.contains(&"email".to_string())); + assert!(required.contains(&"tags".to_string())); + assert!(!required.contains(&"age".to_string())); +} + +#[test] +fn properties_are_correctly_mapped() { + #[allow(unused)] + #[derive(JsonSchema)] + #[mcp_tool(name = "test_props", description = "test_props")] + struct TestProps { + name: String, + count: i32, + active: bool, + score: Option, + } + + let tool: Tool = TestProps::tool(); + let schema = tool.input_schema; + let props = schema.properties.unwrap(); + + assert!(props.contains_key("name")); + assert!(props.contains_key("count")); + assert!(props.contains_key("active")); + assert!(props.contains_key("score")); + + let name_prop = props.get("name").unwrap(); + assert_eq!(name_prop.get("type").unwrap().as_str().unwrap(), "string"); + + let active_prop = props.get("active").unwrap(); + assert_eq!( + active_prop.get("type").unwrap().as_str().unwrap(), + "boolean" + ); +} + +#[test] +fn tool_name_fallback_when_not_provided() { + #[derive(JsonSchema)] + #[mcp_tool(name = "fallback-name-tool", description = "No name, uses struct name")] + struct FallbackNameTool { + input: String, + } + + let tool: Tool = FallbackNameTool::tool(); + assert_eq!(tool.name, "fallback-name-tool"); // Uses struct name +} + +#[test] +fn meta_is_ignored_when_feature_off() { + // Should compile even if meta is provided + #[derive(JsonSchema)] + #[mcp_tool( + name = "old_schema", + description = "old_schema", + meta = r#"{"ignored": true}"# + )] + struct OldTool { + x: i32, + } + + let tool: Tool = OldTool::tool(); + + assert_eq!(tool.name, "old_schema"); + let meta = tool.meta.unwrap(); + assert_eq!(meta, json!({"ignored": true}).as_object().unwrap().clone()); +} + +#[test] +fn readme_example_tool() { + #[mcp_tool( + name = "write_file", + title = "Write File Tool", + description = "Create or overwrite a file with content.", + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false, + execution(task_support = "optional"), + icons = [ + (src = "https:/mywebsite.com/write.png", mime_type = "image/png", sizes = ["128x128"], theme = "light"), + (src = "https:/mywebsite.com/write_dark.svg", mime_type = "image/svg+xml", sizes = ["64x64","128x128"], theme = "dark") + ], + meta = r#"{"key": "value"}"# + )] + #[derive(JsonSchema)] + pub struct WriteFileTool { + /// The target file's path. + pub path: String, + /// The string content to be written to the file + pub content: String, + } + + assert_eq!(WriteFileTool::tool_name(), "write_file"); + + let tool: rust_mcp_schema::Tool = WriteFileTool::tool(); + assert_eq!(tool.name, "write_file"); + assert_eq!(tool.title.as_ref().unwrap(), "Write File Tool"); + assert_eq!( + tool.description.unwrap(), + "Create or overwrite a file with content." + ); + + let icons = tool.icons; + assert_eq!(icons.len(), 2); + assert_eq!(icons[0].src, "https:/mywebsite.com/write.png"); + assert_eq!(icons[0].mime_type, Some("image/png".into())); + assert_eq!(icons[0].theme, Some("light".into())); + assert_eq!(icons[0].sizes, vec!["128x128"]); + assert_eq!(icons[1].mime_type, Some("image/svg+xml".into())); + + let meta: &serde_json::Map = tool.meta.as_ref().unwrap(); + assert_eq!( + meta.get("key").unwrap(), + &serde_json::Value::String("value".to_string()) + ); + + let schema_properties = tool.input_schema.properties.unwrap(); + assert_eq!(schema_properties.len(), 2); + assert!(schema_properties.contains_key("path")); + assert!(schema_properties.contains_key("content")); + + // get the `content` prop from schema + let content_prop = schema_properties.get("content").unwrap(); + + // assert the type + assert_eq!(content_prop.get("type").unwrap(), "string"); + // assert the description + assert_eq!( + content_prop.get("description").unwrap(), + "The string content to be written to the file" + ); + + let request_params = WriteFileTool::request_params().with_arguments( + json!({"path":"./test.txt","content":"hello tool"}) + .as_object() + .unwrap() + .clone(), + ); + + assert_eq!(request_params.name, "write_file"); +} diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 609b0ac..7652841 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -66,8 +66,7 @@ default = [ "sse", "streamable-http", "hyper-server", - "ssl", - "2025_06_18", + "ssl" ] # All features enabled by default sse = ["rust-mcp-transport/sse","http","http-body","http-body-util"] @@ -82,36 +81,6 @@ ssl = ["axum-server/tls-rustls"] tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] -# enables mcp protocol version 2025-06-18 -2025-06-18 = [ - "rust-mcp-schema/2025_06_18", - "rust-mcp-macros/2025_06_18", - "rust-mcp-transport/2025_06_18", - "rust-mcp-schema/schema_utils", -] -# Alias: allow users to use underscores instead of hyphens -2025_06_18 = ["2025-06-18"] - -# enables mcp protocol version 2025_03_26 -2025-03-26 = [ - "rust-mcp-schema/2025_03_26", - "rust-mcp-macros/2025_03_26", - "rust-mcp-transport/2025_03_26", - "rust-mcp-schema/schema_utils", -] -# Alias: allow users to use underscores instead of hyphens -2025_03_26 = ["2025-03-26"] - - -# enables mcp protocol version 2024_11_05 -2024-11-05 = [ - "rust-mcp-schema/2024_11_05", - "rust-mcp-macros/2024_11_05", - "rust-mcp-transport/2024_11_05", - "rust-mcp-schema/schema_utils", -] -# Alias: allow users to use underscores instead of hyphens -2024_11_05 = ["2024-11-05"] [lints] workspace = true diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index d92d964..715f280 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -11,26 +11,21 @@ [Hello World MCP Server ](examples/hello-world-mcp-server-stdio) -A high-performance, asynchronous toolkit for building MCP servers and clients. -Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! -**rust-mcp-sdk** provides the necessary components for developing both servers and clients in the MCP ecosystem. -Leveraging the [rust-mcp-schema](https://github.com/rust-mcp-stack/rust-mcp-schema) crate simplifies the process of building robust and reliable MCP servers and clients, ensuring consistency and minimizing errors in data handling and message processing. +A high-performance, asynchronous Rust toolkit for building MCP servers and clients. +Focus on your application logic - rust-mcp-sdk handles the protocol, transports, and the rest! +This SDK fully implements the latest MCP protocol version ([2025-11-25](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema)), with backward compatibility built-in. `rust-mcp-sdk` provides the necessary components for developing both servers and clients in the MCP ecosystem. It leverages the [rust-mcp-schema](https://crates.io/crates/rust-mcp-schema) crate for type-safe schema objects and includes powerful procedural macros for tools and user input elicitation. -**rust-mcp-sdk** supports all three official versions of the MCP protocol. -By default, it uses the **2025-06-18** version, but earlier versions can be enabled via Cargo features. - -πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - - -**Features** -- βœ… Stdio, SSE and Streamable HTTP Support -- βœ… Supports multiple MCP protocol versions +**Key Features** +- βœ… Latest MCP protocol specification supported: 2025-11-25 +- βœ… Transports:Stdio, Streamable HTTP, and backward-compatible SSE support +- βœ… Lightweight Axum-based server for Streamable HTTP and SSE +- βœ… Multi-client concurrency - βœ… DNS Rebinding Protection +- βœ… Resumability - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- βœ… Resumability - βœ… OAuth Authentication for MCP Servers - βœ… [Remote Oauth Provider](crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs) (for any provider with DCR support) - βœ… **Keycloak** Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#keycloak)) @@ -41,24 +36,26 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents -- [Getting Started](#getting-started) +- [Quick Start](#quick-start) + - [Minimal MCP Server (Stdio)]([#minimal-mcp-server-stdio](#minimal-mcp-server-stdio)) + - [Minimal MCP Server (Streamable HTTP)](#minimal-mcp-server-streamable-http) + - [Minimal MCP Client (Stdio)](#minimal-mcp-client-stdio) - [Usage Examples](#usage-examples) - - [MCP Server (stdio)](#mcp-server-stdio) - - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - - [MCP Client (stdio)](#mcp-client-stdio) - - [MCP Client (Streamable HTTP)](#mcp-client-streamable-http) - - [MCP Client (sse)](#mcp-client-sse) -- [Authentication](#authentication) - [Macros](#macros) + - [mcp_tool](#mcp_tool) + - [tool_box](#-tool_box) + - [mcp_icon](#-mcp_icon) +- [Authentication](#authentication) + - [RemoteAuthProvider](#remoteauthprovider) + - [OAuthProxy](#oauthproxy) - [HyperServerOptions](#hyperserveroptions) - - [Security Considerations](#security-considerations) +- [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) - [Available Features](#available-features) - - [MCP protocol versions with corresponding features](#mcp-protocol-versions-with-corresponding-features) - [Default Features](#default-features) - [Using Only the server Features](#using-only-the-server-features) - [Using Only the client Features](#using-only-the-client-features) -- [Choosing Between Standard and Core Handlers traits](#choosing-between-standard-and-core-handlers-traits) +- [Handler Traits](#handlers-traits) - [Choosing Between **ServerHandler** and **ServerHandlerCore**](#choosing-between-serverhandler-and-serverhandlercore) - [Choosing Between **ClientHandler** and **ClientHandlerCore**](#choosing-between-clienthandler-and-clienthandlercore) - [Projects using Rust MCP SDK](#projects-using-rust-mcp-sdk) @@ -67,330 +64,339 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab - [License](#license) -## Getting Started -If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) -## Usage Examples +## Quick Start -### MCP Server (stdio) + -Create a MCP server with a `tool` that will print a `Hello World!` message: +Add to your Cargo.toml: +```toml +[dependencies] +rust-mcp-sdk = "0.9.0" # Check crates.io for the latest version +``` + + + +## Minimal MCP Server (Stdio) +```rs +use async_trait::async_trait; +use rust_mcp_sdk::{*,error::SdkResult,macros,mcp_server::{server_runtime, ServerHandler},schema::*,}; + +// Define a mcp tool +#[macros::mcp_tool(name = "say_hello", description = "returns \"Hello from Rust MCP SDK!\" message ")] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] +pub struct SayHelloTool {} + +// define a custom handler +#[derive(Default)] +struct HelloHandler; + +// implement ServerHandler +#[async_trait] +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + tools: vec![SayHelloTool::tool()], + meta: None, + next_cursor: None, + }) + } + // Handles requests to call a specific tool. + async fn handle_call_tool_request(&self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" { + Ok(CallToolResult::text_content(vec!["Hello from Rust MCP SDK!".into()])) + } else { + Err(CallToolError::unknown_tool(params.name)) + } + } +} -```rust #[tokio::main] async fn main() -> SdkResult<()> { - - // STEP 1: Define server details and capabilities - let server_details = InitializeResult { - // server name and version + // Define server details and capabilities + let server_info = InitializeResult { server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), - }, - capabilities: ServerCapabilities { - // indicates that server support mcp tools - tools: Some(ServerCapabilitiesTools { list_changed: None }), - ..Default::default() // Using default values for other fields + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, - meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None }; - // STEP 2: create a std transport with default options let transport = StdioTransport::new(TransportOptions::default())?; - - // STEP 3: instantiate our custom handler for handling MCP messages - let handler = MyServerHandler {}; - - // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); - - // STEP 5: Start the server + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = server_runtime::create_server(server_info, transport, handler); server.start().await - } ``` -See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : - -![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) - -### MCP Server (Streamable HTTP) - -Creating an MCP server in `rust-mcp-sdk` with the `sse` transport allows multiple clients to connect simultaneously with no additional setup. -Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and HyperServerOptions. - - -πŸ’‘ By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. To disable the SSE transport , set the `sse_support` to false in the `HyperServerOptions`. - +## Minimal MCP Server (Streamable HTTP) +Creating an MCP server in `rust-mcp-sdk` allows multiple clients to connect simultaneously with no additional setup. +The setup is nearly identical to the stdio example shown above. You only need to create a Hyper server via `hyper_server::create_server()` and pass in the same handler and `HyperServerOptions`. +πŸ’‘ If backward compatibility is required, you can enable **SSE** transport by setting `sse_support` to true in `HyperServerOptions`. ```rust - -// STEP 1: Define server details and capabilities -let server_details = InitializeResult { - // server name and version - server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), - }, - capabilities: ServerCapabilities { - // indicates that server support mcp tools - tools: Some(ServerCapabilitiesTools { list_changed: None }), - ..Default::default() // Using default values for other fields - }, - meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), +use async_trait::async_trait; +use rust_mcp_sdk::{*,error::SdkResult,event_store::InMemoryEventStore,macros, + mcp_server::{hyper_server, HyperServerOptions, ServerHandler},schema::*, }; -// STEP 2: instantiate our custom handler for handling MCP messages -let handler = MyServerHandler {}; - -// STEP 3: instantiate HyperServer, providing `server_details` , `handler` and HyperServerOptions -let server = hyper_server::create_server( - server_details, - handler, - HyperServerOptions { - host: "127.0.0.1".to_string(), - sse_support: false, - event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability - ..Default::default() - }, -); - -// STEP 4: Start the server -server.start().await?; - -Ok(()) -``` - - -The implementation of `MyServerHandler` is the same regardless of the transport used and could be as simple as the following: - -```rust - -// STEP 1: Define a rust_mcp_schema::Tool ( we need one with no parameters for this example) -#[mcp_tool(name = "say_hello_world", description = "Prints \"Hello World!\" message")] -#[derive(Debug, Deserialize, Serialize, JsonSchema)] +// Define a mcp tool +#[macros::mcp_tool( + name = "say_hello", + description = "returns \"Hello from Rust MCP SDK!\" message " +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] pub struct SayHelloTool {} -// STEP 2: Implement ServerHandler trait for a custom handler -// For this example , we only need handle_list_tools_request() and handle_call_tool_request() methods. -pub struct MyServerHandler; +// define a custom handler +#[derive(Default)] +struct HelloHandler; +// implement ServerHandler #[async_trait] -impl ServerHandler for MyServerHandler { - // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { - - Ok(ListToolsResult { - tools: vec![SayHelloTool::tool()], - meta: None, - next_cursor: None, - }) - +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult {tools: vec![SayHelloTool::tool()],meta: None,next_cursor: None}) } - - /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { - - if request.tool_name() == SayHelloTool::tool_name() { - Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) + // Handles requests to call a specific tool. + async fn handle_call_tool_request( + &self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" {Ok(CallToolResult::text_content(vec!["Hello from Rust MCP SDK!".into()])) } else { - Err(CallToolError::unknown_tool(request.tool_name().to_string())) + Err(CallToolError::unknown_tool(params.name)) } - } } -``` - ---- - -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** -See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +#[tokio::main] +async fn main() -> SdkResult<()> { + // Define server details and capabilities + let server_info = InitializeResult { + server_info: Implementation { + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), + }, + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None + }; -![mcp-server in rust](assets/examples/hello-world-server-streamable-http.gif) + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = hyper_server::create_server( + server_info, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability + ..Default::default() + }, + ); + server.start().await?; + Ok(()) +} +``` ---- -### MCP Client (stdio) +## Minimal MCP Client (Stdio) +Following is implementation of an MCP client that starts the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, displays the server's name, version, and list of tools provided by the server. -Create an MCP client that starts the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, displays the server's name, version, and list of tools, then uses the add tool provided by the server to sum 120 and 28, printing the result. ```rust +use async_trait::async_trait; +use rust_mcp_sdk::{*, error::SdkResult, + mcp_client::{client_runtime, ClientHandler}, + schema::*, +}; -// STEP 1: Custom Handler to handle incoming MCP Messages +// Custom Handler to handle incoming MCP Messages pub struct MyClientHandler; - #[async_trait] impl ClientHandler for MyClientHandler { - // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs + // To see all the trait methods you can override, + // check out: + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs } #[tokio::main] async fn main() -> SdkResult<()> { - - // Step2 : Define client details and capabilities + // Client details and capabilities let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { name: "simple-rust-mcp-client".into(), version: "0.1.0".into(), + description: None, + icons: vec![], + title: None, + website_url: None, }, - protocol_version: LATEST_PROTOCOL_VERSION.into(), + protocol_version: ProtocolVersion::V2025_11_25.into(), + meta: None, }; - // Step3 : Create a transport, with options to launch @modelcontextprotocol/server-everything MCP Server + // Create a transport, with options to launch @modelcontextprotocol/server-everything MCP Server let transport = StdioTransport::create_with_server_launch( - "npx", - vec![ "-y".to_string(), "@modelcontextprotocol/server-everything".to_string()], - None, TransportOptions::default() + "npx",vec!["-y".to_string(),"@modelcontextprotocol/server-everything@latest".to_string()], + None, + TransportOptions::default(), )?; - // STEP 4: instantiate our custom handler for handling MCP messages + // instantiate our custom handler for handling MCP messages let handler = MyClientHandler {}; - // STEP 5: create a MCP client - let client = client_runtime::create_client(client_details, transport, handler); - - // STEP 6: start the MCP client + // Create and start the MCP client + let client = client_runtime::create_client(client_details, transport, handler); client.clone().start().await?; + // use client methods to communicate with the MCP Server as you wish: - // STEP 7: use client methods to communicate with the MCP Server as you wish - + let server_version = client.server_version().unwrap(); + // Retrieve and display the list of tools available on the server - let server_version = client.server_version().unwrap(); - let tools = client.list_tools(None).await?.tools; - - println!("List of tools for {}@{}", server_version.name, server_version.version); - + let tools = client.request_tool_list(None).await?.tools; + println!( "List of tools for {}@{}",server_version.name, server_version.version); tools.iter().enumerate().for_each(|(tool_index, tool)| { - println!(" {}. {} : {}", - tool_index + 1, - tool.name, - tool.description.clone().unwrap_or_default() - ); + println!(" {}. {} : {}", tool_index + 1, tool.name, tool.description.clone().unwrap_or_default()); }); - println!("Call \"add\" tool with 100 and 28 ..."); - // Create a `Map` to represent the tool parameters - let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); - let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; - - // invoke the tool - let result = client.call_tool(request).await?; - - println!("{}",result.content.first().unwrap().as_text_content()?.text); - client.shut_down().await?; - Ok(()) } - ``` -Here is the output : +## Usage Examples -![rust-mcp-sdk-client-output](assets/examples/mcp-client-sample-code.jpg) +πŸ‘‰ For full examples (stdio, Streamable HTTP, clients, auth, etc.), see the [examples/](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples) directory. -> your results may vary slightly depending on the version of the MCP Server in use when you run it. +πŸ‘‰ If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) -### MCP Client (Streamable HTTP) -```rs +See [hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : -// STEP 1: Custom Handler to handle incoming MCP Messages -pub struct MyClientHandler; +![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) -#[async_trait] -impl ClientHandler for MyClientHandler { - // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs -} -#[tokio::main] -async fn main() -> SdkResult<()> { +## Macros +Enable with the `macros` feature. - // Step2 : Define client details and capabilities - let client_details: InitializeRequestParams = InitializeRequestParams { - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "simple-rust-mcp-client-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (SSE)".to_string()), - }, - protocol_version: LATEST_PROTOCOL_VERSION.into(), - }; +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. - // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. - let transport_options = StreamableTransportOptions { - mcp_url: MCP_SERVER_URL.to_string(), - request_options: RequestOptions { - ..RequestOptions::default() - }, - }; +### β—Ύ`mcp_tool` +Generate a [Tool](https://docs.rs/rust-mcp-schema/latest/rust_mcp_schema/struct.Tool.html) from a struct, with rich metadata (icons, execution hints, etc.). - // STEP 4: instantiate the custom handler that is responsible for handling MCP messages - let handler = MyClientHandler {}; +example usage: +```rs +#[mcp_tool( + name = "write_file", + title = "Write File Tool", + description = "Create a new file or completely overwrite an existing file with new content.", + destructive_hint = false idempotent_hint = false open_world_hint = false read_only_hint = false, + meta = r#"{ "key" : "value", "string_meta" : "meta value", "numeric_meta" : 15}"#, + execution(task_support = "optional"), + icons = [(src = "https:/website.com/write.png", mime_type = "image/png", sizes = ["128x128"], theme = "light")] +)] +#[derive(rust_mcp_macros::JsonSchema)] +pub struct WriteFileTool { + /// The target file's path for writing content. + pub path: String, + /// The string content to be written to the file + pub content: String, +} +``` - // STEP 5: create the client with transport options and the handler - let client = client_runtime::with_transport_options(client_details, transport_options, handler); +πŸ“ For complete documentation, example usage, and a list of all available attributes, please refer to https://crates.io/crates/rust-mcp-macros. - // STEP 6: start the MCP client - client.clone().start().await?; +### β—Ύ `tool_box!()` +Automatically generates an enum based on the provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. - // STEP 7: use client methods to communicate with the MCP Server as you wish +```rs +tool_box!(GreetingTools, [SayHelloTool, SayGoodbyeTool]); - // Retrieve and display the list of tools available on the server - let server_version = client.server_version().unwrap(); - let tools = client.list_tools(None).await?.tools; - println!("List of tools for {}@{}", server_version.name, server_version.version); +let tools: Vec = GreetingTools::tools(); +`` - tools.iter().enumerate().for_each(|(tool_index, tool)| { - println!(" {}. {} : {}", - tool_index + 1, - tool.name, - tool.description.clone().unwrap_or_default() - ); - }); +πŸ’» For a real-world example, check out [tools/](https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools) and +[handle_call_tool_request(...)](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L195) in [rust-mcp-filesystem](https://github.com/rust-mcp-stack/rust-mcp-filesystem) project - println!("Call \"add\" tool with 100 and 28 ..."); - // Create a `Map` to represent the tool parameters - let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); - let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; +### β—Ύ [mcp_elicit](https://crates.io/crates/rust-mcp-macros) +Generates type-safe elicitation (Form or URL mode) for user input. - // invoke the tool - let result = client.call_tool(request).await?; +example usage: +```rs +#[mcp_elicit(message = "Please enter your info", mode = form)] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, +} - println!("{}",result.content.first().unwrap().as_text_content()?.text); +// Sends a request to the client asking the user to provide input +let result: ElicitResult = server.request_elicitation(UserInfo::elicit_request_params()).await?; - client.shut_down().await?; +// Convert result.content into a UserInfo instance +let user_info = UserInfo::from_elicit_result_content(result.content)?; - Ok(()) +println!("name: {}", user_info.name); +println!("age: {}", user_info.age); +println!("email: {}",user.email.clone().unwrap_or("not provider".into())); +println!("tags: {}", user_info.tags.join(",")); ``` -πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. +πŸ“ For complete documentation, example usage, and a list of all available attributes, please refer to https://crates.io/crates/rust-mcp-macros. +### β—Ύ `mcp_icon!()` +A convenient icon builder for implementations and tools, offering full attribute support including theme, size, mime, and more. -### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: - -```diff -- let transport = StdioTransport::create_with_server_launch( -- "npx", -- vec![ "-y".to_string(), "@modelcontextprotocol/server-everything".to_string()], -- None, TransportOptions::default() --)?; -+ let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; +example usage: +```rs +let icon: crate::schema::Icon = mcp_icon!( + src = "http://website.com/icon.png", + mime_type = "image/png", + sizes = ["64x64"], + theme = "dark" + ); ``` -πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. - - ## Authentication MCP server can verify tokens issued by other systems, integrate with external identity providers, or manage the entire authentication process itself. Each option offers a different balance of simplicity, security, and control. @@ -404,120 +410,12 @@ MCP server can verify tokens issued by other systems, integrate with external id - [WorkOS autn example](crates/rust-mcp-extra/README.md#workos-authkit) - ### OAuthProxy OAuthProxy enables authentication with OAuth providers that don’t support Dynamic Client Registration (DCR).It accepts any client registration request, handles the DCR on your server side and then uses your pre-registered app credentials upstream.The proxy also forwards callbacks, allowing dynamic redirect URIs to work with providers that require fixed ones. > ⚠️ OAuthProxy support is still in development, please use RemoteAuthProvider for now. - - -## Macros -[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. - -> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. - -### mcp_tool -`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. - -Usage example: -```rust -#[mcp_tool( - name = "move_file", - title="Move File", - description = concat!("Move or rename files and directories. Can move files between directories ", -"and rename them in a single operation. If the destination exists, the ", -"operation will fail. Works across different directories and can be used ", -"for simple renaming within the same directory. ", -"Both source and destination must be within allowed directories."), - destructive_hint = false, - idempotent_hint = false, - open_world_hint = false, - read_only_hint = false -)] -#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] -pub struct MoveFileTool { - /// The source path of the file to move. - pub source: String, - /// The destination path to move the file to. - pub destination: String, -} - -// Now we can call `tool()` method on it to get a Tool instance -let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); - -``` - -πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools - - -### tool_box -`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. - -It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. - -Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. - -Usage example: -```rust - // Accepts an array of tools and generates an enum named `FileSystemTools`, - // where each tool becomes a variant of the enum. - tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); - - // now in the app, we can use the FileSystemTools, like: - let all_tools: Vec = FileSystemTools::tools(); -``` - -πŸ’» To see a real-world example of that please see : -- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) -- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) -- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) - - - -### mcp_elicit -The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: -- A `message()` method returning the elicitation message as a string. -- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. -- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. - -### Attributes - -- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. - -Usage example: -```rust -// A struct that could be used to send elicit request and get the input from the user -#[mcp_elicit(message = "Please enter your info")] -#[derive(JsonSchema)] -pub struct UserInfo { - #[json_schema( - title = "Name", - description = "The user's full name", - min_length = 5, - max_length = 100 - )] - pub name: String, - /// Is user a student? - #[json_schema(title = "Is student?", default = true)] - pub is_student: Option, - - /// User's favorite color - pub favorate_color: Colors, -} - -// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance -let result: ElicitResult = server - .elicit_input(UserInfo::message(), UserInfo::requested_schema()) - .await?; - -// Create a UserInfo instance using data provided by the user on the client side -let user_info = UserInfo::from_content_map(result.content)?; - -``` -πŸ’» For mre info please see : -- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros ## HyperServerOptions @@ -531,89 +429,22 @@ A typical example of creating a HyperServer that exposes the MCP server via Stre let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { host: "127.0.0.1".to_string(), - enable_ssl: true, + port: 8080, + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability + auth: Some(Arc::new(auth_provider)), // enable authentication + sse_support: false, ..Default::default() }, ); server.start().await?; - ``` -Here is a list of available options with descriptions for configuring the HyperServer: -```rs - -pub struct HyperServerOptions { - /// Hostname or IP address the server will bind to (default: "127.0.0.1") - pub host: String, - - /// Hostname or IP address the server will bind to (default: "8080") - pub port: u16, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>>, - - /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) - pub custom_streamable_http_endpoint: Option, - - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Event store for resumability support - /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages - pub event_store: Option>, - - /// This setting only applies to streamable HTTP. - /// If true, the server will return JSON responses instead of starting an SSE stream. - /// This can be useful for simple request/response scenarios without streaming. - /// Default is false (SSE streams are preferred). - pub enable_json_response: Option, - - /// Interval between automatic ping messages sent to clients to detect disconnects - pub ping_interval: Duration, - - /// Enables SSL/TLS if set to `true` - pub enable_ssl: bool, - - /// Path to the SSL/TLS certificate file (e.g., "cert.pem"). - /// Required if `enable_ssl` is `true`. - pub ssl_cert_path: Option, - - /// Path to the SSL/TLS private key file (e.g., "key.pem"). - /// Required if `enable_ssl` is `true`. - pub ssl_key_path: Option, +πŸ“ Refer to [HyperServerOptions](https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/hyper_servers/server.rs#L43) for a complete overview of HyperServerOptions attributes and options. - /// List of allowed host header values for DNS rebinding protection. - /// If not specified, host validation is disabled. - pub allowed_hosts: Option>, - - /// List of allowed origin header values for DNS rebinding protection. - /// If not specified, origin validation is disabled. - pub allowed_origins: Option>, - - /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). - /// Default is false for backwards compatibility. - pub dns_rebinding_protection: bool, - - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - - /// Optional authentication provider for protecting MCP server. - pub auth: Option>, -} - -``` ### Security Considerations @@ -637,28 +468,19 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. - - `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. -#### MCP Protocol Versions with Corresponding Features - -- `2025_06_18` : Activates MCP Protocol version 2025-06-18 (enabled by default) -- `2025_03_26` : Activates MCP Protocol version 2025-03-26 -- `2024_11_05` : Activates MCP Protocol version 2024-11-05 - -> Note: MCP protocol versions are mutually exclusive-only one can be active at any given time. - ### Default Features -When you add rust-mcp-sdk as a dependency without specifying any features, all features are included, with the latest MCP Protocol version enabled by default: +When you add rust-mcp-sdk as a dependency without specifying any features, all features are enabled by default ```toml [dependencies] -rust-mcp-sdk = "0.2.0" +rust-mcp-sdk = "0.9.0" ``` diff --git a/crates/rust-mcp-sdk/examples/quick_start.rs b/crates/rust-mcp-sdk/examples/quick_start.rs new file mode 100644 index 0000000..dfc8999 --- /dev/null +++ b/crates/rust-mcp-sdk/examples/quick_start.rs @@ -0,0 +1,77 @@ +use async_trait::async_trait; +use rust_mcp_sdk::{ + error::SdkResult, + macros, + mcp_server::{server_runtime, ServerHandler}, + schema::*, + *, +}; + +// Define a mcp tool +#[macros::mcp_tool( + name = "say_hello", + description = "returns \"Hello from Rust MCP SDK!\" message " +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] +pub struct SayHelloTool {} + +// define a custom handler +#[derive(Default)] +struct HelloHandler {} + +// implement ServerHandler +#[async_trait] +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + tools: vec![SayHelloTool::tool()], + meta: None, + next_cursor: None, + }) + } + // Handles requests to call a specific tool. + async fn handle_call_tool_request( + &self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" { + Ok(CallToolResult::text_content(vec![ + "Hello from Rust MCP SDK!".into(), + ])) + } else { + Err(CallToolError::unknown_tool(params.name)) + } + } +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + let server_info = InitializeResult { + server_info: Implementation { + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), + }, + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None + }; + + let transport = StdioTransport::new(TransportOptions::default())?; + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = server_runtime::create_server(server_info, transport, handler); + server.start().await +} diff --git a/crates/rust-mcp-sdk/examples/quick_start_client_stdio.rs b/crates/rust-mcp-sdk/examples/quick_start_client_stdio.rs new file mode 100644 index 0000000..377fe4e --- /dev/null +++ b/crates/rust-mcp-sdk/examples/quick_start_client_stdio.rs @@ -0,0 +1,92 @@ +use async_trait::async_trait; +use rust_mcp_sdk::{ + error::SdkResult, + mcp_client::{client_runtime, ClientHandler}, + schema::*, + *, +}; + +// Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; +#[async_trait] +impl ClientHandler for MyClientHandler { + // To see all the trait methods you can override, + // check out: + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + // Client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client".into(), + version: "0.1.0".into(), + description: None, + icons: vec![], + title: None, + website_url: None, + }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + meta: None, + }; + + // Create a transport, with options to launch @modelcontextprotocol/server-everything MCP Server + let transport = StdioTransport::create_with_server_launch( + "npx", + vec![ + "-y".to_string(), + "@modelcontextprotocol/server-everything@latest".to_string(), + ], + None, + TransportOptions::default(), + )?; + + // instantiate our custom handler for handling MCP messages + let handler = MyClientHandler {}; + + // Create and start the MCP client + let client = client_runtime::create_client(client_details, transport, handler); + client.clone().start().await?; + + // use client methods to communicate with the MCP Server as you wish: + + let server_version = client.server_version().unwrap(); + + // Retrieve and display the list of tools available on the server + let tools = client.request_tool_list(None).await?.tools; + println!( + "List of tools for {}@{}", + server_version.name, server_version.version + ); + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!( + " {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + let params = serde_json::json!({"a": 100,"b": 28}) + .as_object() + .unwrap() + .clone(); + let request = CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + meta: None, + task: None, + }; + // invoke the tool + let result = client.request_tool_call(request).await?; + println!( + "{}", + result.content.first().unwrap().as_text_content()?.text + ); + + client.shut_down().await?; + Ok(()) +} diff --git a/crates/rust-mcp-sdk/examples/quick_start_streamable_http.rs b/crates/rust-mcp-sdk/examples/quick_start_streamable_http.rs new file mode 100644 index 0000000..bb9b5ca --- /dev/null +++ b/crates/rust-mcp-sdk/examples/quick_start_streamable_http.rs @@ -0,0 +1,87 @@ +use async_trait::async_trait; +use rust_mcp_sdk::{ + error::SdkResult, + event_store::InMemoryEventStore, + macros, + mcp_server::{hyper_server, HyperServerOptions, ServerHandler}, + schema::*, + *, +}; + +// Define a mcp tool +#[macros::mcp_tool( + name = "say_hello", + description = "returns \"Hello from Rust MCP SDK!\" message " +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, macros::JsonSchema)] +pub struct SayHelloTool {} + +// define a custom handler +#[derive(Default)] +struct HelloHandler {} + +// implement ServerHandler +#[async_trait] +impl ServerHandler for HelloHandler { + // Handles requests to list available tools. + async fn handle_list_tools_request( + &self, + _request: Option, + _runtime: std::sync::Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + tools: vec![SayHelloTool::tool()], + meta: None, + next_cursor: None, + }) + } + // Handles requests to call a specific tool. + async fn handle_call_tool_request( + &self, + params: CallToolRequestParams, + _runtime: std::sync::Arc, + ) -> std::result::Result { + if params.name == "say_hello" { + Ok(CallToolResult::text_content(vec![ + "Hello from Rust MCP SDK!".into(), + ])) + } else { + Err(CallToolError::unknown_tool(params.name)) + } + } +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + // Define server details and capabilities + let server_info = InitializeResult { + server_info: Implementation { + name: "hello-rust-mcp".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("A minimal Rust MCP server".into()), + icons: vec![mcp_icon!(src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "light")], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), + }, + capabilities: ServerCapabilities { tools: Some(ServerCapabilitiesTools { list_changed: None }), ..Default::default() }, + protocol_version: ProtocolVersion::V2025_11_25.into(), + instructions: None, + meta:None + }; + + let handler = HelloHandler::default().to_mcp_server_handler(); + let server = hyper_server::create_server( + server_info, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + event_store: Some(std::sync::Arc::new(InMemoryEventStore::default())), // enable resumability + ..Default::default() + }, + ); + server.start().await?; + Ok(()) +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 5cedb59..4b41592 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -1,30 +1,33 @@ use std::{sync::Arc, time::Duration}; +use crate::{ + error::SdkResult, + mcp_server::{ + error::{TransportServerError, TransportServerResult}, + ServerRuntime, + }, +}; use crate::{ mcp_http::McpAppState, mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams, - ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams, - PromptListChangedNotificationParams, ResourceListChangedNotificationParams, - ResourceUpdatedNotificationParams, ToolListChangedNotificationParams, + CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams, ListRootsResult, + LoggingMessageNotificationParams, NotificationParams, RequestParams, + ResourceUpdatedNotificationParams, }, McpServer, }; - use axum_server::Handle; +use rust_mcp_schema::{ + schema_utils::{CustomNotification, CustomRequest}, + CancelTaskParams, CancelTaskResult, CancelledNotificationParams, ElicitCompleteParams, + ElicitRequestParams, ElicitResult, GenericResult, GetTaskParams, GetTaskPayloadParams, + GetTaskResult, ProgressNotificationParams, TaskStatusNotificationParams, +}; use rust_mcp_transport::SessionId; use tokio::task::JoinHandle; -use crate::{ - error::SdkResult, - mcp_server::{ - error::{TransportServerError, TransportServerResult}, - ServerRuntime, - }, -}; - pub struct HyperRuntime { pub(crate) state: Arc, pub(crate) server_task: JoinHandle>, @@ -85,6 +88,11 @@ impl HyperRuntime { ) } + /// Sends a request to the client and processes the response. + /// + /// This function sends a `RequestFromServer` message to the client, waits for the response, + /// and handles the result. If the response is empty or of an invalid type, an error is returned. + /// Otherwise, it returns the result from the client. pub async fn send_request( &self, session_id: &SessionId, @@ -104,115 +112,317 @@ impl HyperRuntime { runtime.send_notification(notification).await } + pub async fn client_info( + &self, + session_id: &SessionId, + ) -> SdkResult> { + let runtime = self.runtime_by_session(session_id).await?; + Ok(runtime.client_info()) + } + + /******************* + Requests + *******************/ + + /// Sends an elicitation request to the client to prompt user input and returns the received response. + /// + /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema. + /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only + pub async fn request_elicitation( + &self, + session_id: &SessionId, + params: ElicitRequestParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_elicitation(params).await + } + /// Request a list of root URIs from the client. Roots allow /// servers to ask for specific directories or files to operate on. A common example /// for roots is providing a set of repositories or directories a server should operate on. /// This request is typically used when the server needs to understand the file system /// structure or access specific locations that the client has permission to read from - pub async fn list_roots( + pub async fn request_root_list( &self, session_id: &SessionId, - params: Option, + params: Option, ) -> SdkResult { let runtime = self.runtime_by_session(session_id).await?; - runtime.list_roots(params).await + runtime.request_root_list(params).await } - pub async fn send_logging_message( + /// A ping request to check that the other party is still alive. + /// The receiver must promptly respond, or else may be disconnected. + /// + /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response + /// Once the response is received, it attempts to convert it into the expected + /// result type. + /// + /// # Returns + /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. + /// If the request or conversion fails, an error is returned. + pub async fn ping( + &self, + session_id: &SessionId, + params: Option, + timeout: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.ping(params, timeout).await + } + + /// A request from the server to sample an LLM via the client. + /// The client has full discretion over which model to select. + /// The client should also inform the user before beginning sampling, + /// to allow them to inspect the request (human in the loop) + /// and decide whether to approve it. + pub async fn request_message_creation( + &self, + session_id: &SessionId, + params: CreateMessageRequestParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_message_creation(params).await + } + + ///Send a request to retrieve the state of a task. + pub async fn request_get_task( + &self, + session_id: &SessionId, + params: GetTaskParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_get_task(params).await + } + + ///Send a request to retrieve the result of a completed task. + pub async fn request_get_task_payload( + &self, + session_id: &SessionId, + params: GetTaskPayloadParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_get_task_payload(params).await + } + + ///Send a request to cancel a task. + pub async fn request_task_cancellation( + &self, + session_id: &SessionId, + params: CancelTaskParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_task_cancellation(params).await + } + + ///Send a custom request with a custom method name and params + pub async fn request_custom( + &self, + session_id: &SessionId, + params: CustomRequest, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_custom(params).await + } + + /******************* + Notifications + *******************/ + + /// Send log message notification from server to client. + /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. + pub async fn notify_log_message( &self, session_id: &SessionId, params: LoggingMessageNotificationParams, ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.send_logging_message(params).await + runtime.notify_log_message(params).await } - /// An optional notification from the server to the client, informing it that + ///Send an optional notification from the server to the client, informing it that /// the list of prompts it offers has changed. /// This may be issued by servers without any previous subscription from the client. - pub async fn send_prompt_list_changed( + pub async fn notify_prompt_list_changed( &self, session_id: &SessionId, - params: Option, + params: Option, ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.send_prompt_list_changed(params).await + runtime.notify_prompt_list_changed(params).await } - /// An optional notification from the server to the client, + ///Send an optional notification from the server to the client, /// informing it that the list of resources it can read from has changed. /// This may be issued by servers without any previous subscription from the client. - pub async fn send_resource_list_changed( + pub async fn notify_resource_list_changed( &self, session_id: &SessionId, - params: Option, + params: Option, ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.send_resource_list_changed(params).await + runtime.notify_resource_list_changed(params).await } - /// A notification from the server to the client, informing it that + ///Send a notification from the server to the client, informing it that /// a resource has changed and may need to be read again. /// This should only be sent if the client previously sent a resources/subscribe request. - pub async fn send_resource_updated( + pub async fn notify_resource_updated( &self, session_id: &SessionId, params: ResourceUpdatedNotificationParams, ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.send_resource_updated(params).await + runtime.notify_resource_updated(params).await } - /// An optional notification from the server to the client, informing it that + ///Send an optional notification from the server to the client, informing it that /// the list of tools it offers has changed. /// This may be issued by servers without any previous subscription from the client. - pub async fn send_tool_list_changed( + pub async fn notify_tool_list_changed( &self, session_id: &SessionId, - params: Option, + params: Option, ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.send_tool_list_changed(params).await + runtime.notify_tool_list_changed(params).await } - /// A ping request to check that the other party is still alive. - /// The receiver must promptly respond, or else may be disconnected. - /// - /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response - /// Once the response is received, it attempts to convert it into the expected - /// result type. - /// - /// # Returns - /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. - /// If the request or conversion fails, an error is returned. - pub async fn ping( + /// This notification can be sent to indicate that it is cancelling a previously-issued request. + /// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + /// This notification indicates that the result will be unused, so any associated processing SHOULD cease. + /// A client MUST NOT attempt to cancel its initialize request. + /// For task cancellation, use the tasks/cancel request instead of this notification. + pub async fn notify_cancellation( &self, session_id: &SessionId, - timeout: Option, - ) -> SdkResult { + params: CancelledNotificationParams, + ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - runtime.ping(timeout).await + runtime.notify_cancellation(params).await } - /// A request from the server to sample an LLM via the client. - /// The client has full discretion over which model to select. - /// The client should also inform the user before beginning sampling, - /// to allow them to inspect the request (human in the loop) - /// and decide whether to approve it. + ///Send an out-of-band notification used to inform the receiver of a progress update for a long-running request. + pub async fn notify_progress( + &self, + session_id: &SessionId, + params: ProgressNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_progress(params).await + } + + /// Send an optional notification from the receiver to the requestor, informing them that a task's status has changed. + /// Receivers are not required to send these notifications. + pub async fn notify_task_status( + &self, + session_id: &SessionId, + params: TaskStatusNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_task_status(params).await + } + + ///An optional notification from the server to the client, informing it of a completion of a out-of-band elicitation request. + pub async fn notify_elicitation_completed( + &self, + session_id: &SessionId, + params: ElicitCompleteParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_elicitation_completed(params).await + } + + ///Send a custom notification + pub async fn notify_custom( + &self, + session_id: &SessionId, + params: CustomNotification, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_custom(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")] + pub async fn list_roots( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_root_list(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")] + pub async fn elicit_input( + &self, + session_id: &SessionId, + params: ElicitRequestParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + runtime.request_elicitation(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")] pub async fn create_message( &self, session_id: &SessionId, params: CreateMessageRequestParams, ) -> SdkResult { let runtime = self.runtime_by_session(session_id).await?; - runtime.create_message(params).await + runtime.request_message_creation(params).await } - pub async fn client_info( + #[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")] + pub async fn send_tool_list_changed( &self, session_id: &SessionId, - ) -> SdkResult> { + params: Option, + ) -> SdkResult<()> { let runtime = self.runtime_by_session(session_id).await?; - Ok(runtime.client_info()) + runtime.notify_tool_list_changed(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")] + pub async fn send_resource_updated( + &self, + session_id: &SessionId, + params: ResourceUpdatedNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_resource_updated(params).await + } + + #[deprecated( + since = "0.8.0", + note = "Use `notify_resource_list_changed()` instead." + )] + pub async fn send_resource_list_changed( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_resource_list_changed(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")] + pub async fn send_prompt_list_changed( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_prompt_list_changed(params).await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")] + pub async fn send_logging_message( + &self, + session_id: &SessionId, + params: LoggingMessageNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + runtime.notify_log_message(params).await } } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_server.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_server.rs index b85b8ba..986fa22 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_server.rs @@ -1,10 +1,7 @@ -use std::sync::Arc; - -use crate::schema::InitializeResult; - -use crate::mcp_server::{server_runtime::ServerRuntimeInternalHandler, ServerHandler}; - use super::{HyperServer, HyperServerOptions}; +use crate::mcp_traits::McpServerHandler; +use crate::schema::InitializeResult; +use std::sync::Arc; /// Creates a new HyperServer instance with the provided handler and options /// The handler must implement ServerHandler. @@ -18,12 +15,8 @@ use super::{HyperServer, HyperServerOptions}; /// * `HyperServer` - A configured HyperServer instance ready to start pub fn create_server( server_details: InitializeResult, - handler: impl ServerHandler, + handler: Arc, server_options: HyperServerOptions, ) -> HyperServer { - HyperServer::new( - server_details, - Arc::new(ServerRuntimeInternalHandler::new(Box::new(handler))), - server_options, - ) + HyperServer::new(server_details, handler, server_options) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_server_core.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_server_core.rs index 9599134..6fb57ee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_server_core.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_server_core.rs @@ -1,5 +1,5 @@ use super::{HyperServer, HyperServerOptions}; -use crate::mcp_server::{server_runtime_core::RuntimeCoreInternalHandler, ServerHandlerCore}; +use crate::mcp_traits::McpServerHandler; use crate::schema::InitializeResult; use std::sync::Arc; @@ -15,12 +15,8 @@ use std::sync::Arc; /// * `HyperServer` - A configured HyperServer instance ready to start pub fn create_server( server_details: InitializeResult, - handler: impl ServerHandlerCore, + handler: Arc, server_options: HyperServerOptions, ) -> HyperServer { - HyperServer::new( - server_details, - Arc::new(RuntimeCoreInternalHandler::new(Box::new(handler))), - server_options, - ) + HyperServer::new(server_details, handler, server_options) } diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 1d6476a..42dab07 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -68,7 +68,7 @@ pub mod mcp_server { //! handle each message based on its type and parameters. //! //! Refer to [examples/hello-world-mcp-server-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) for an example. - pub use super::mcp_handlers::mcp_server_handler::{ServerHandler, ToMcpServerHandler}; + pub use super::mcp_handlers::mcp_server_handler::ServerHandler; pub use super::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; pub use super::mcp_runtimes::server_runtime::mcp_server_runtime as server_runtime; @@ -83,6 +83,7 @@ pub mod mcp_server { #[cfg(feature = "hyper-server")] pub use super::mcp_http::{McpAppState, McpHttpHandler}; + pub use super::mcp_traits::{McpServerHandler, ToMcpServerHandler, ToMcpServerHandlerCore}; } pub mod auth; @@ -96,4 +97,8 @@ pub mod macros { } pub mod id_generator; -pub mod schema; + +pub mod schema { + pub use rust_mcp_schema::schema_utils::*; + pub use rust_mcp_schema::*; +} diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index e78db9a..f3f3d30 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -1,143 +1,238 @@ +use crate::mcp_traits::McpClient; +use crate::schema::schema_utils::{CustomNotification, CustomRequest}; use crate::schema::{ - CancelledNotification, CreateMessageRequest, CreateMessageResult, ListRootsRequest, - ListRootsResult, LoggingMessageNotification, PingRequest, ProgressNotification, - PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification, - Result, RpcError, ToolListChangedNotification, + CancelTaskParams, CancelTaskRequest, CancelTaskResult, CancelledNotificationParams, + CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, ElicitCompleteParams, + ElicitRequest, ElicitRequestParams, ElicitResult, GenericResult, GetTaskParams, + GetTaskPayloadParams, GetTaskPayloadRequest, GetTaskRequest, GetTaskResult, ListRootsRequest, + ListRootsResult, ListTasksRequest, ListTasksResult, LoggingMessageNotificationParams, + NotificationParams, PaginatedRequestParams, ProgressNotificationParams, RequestParams, + ResourceUpdatedNotificationParams, Result, RpcError, TaskStatusNotificationParams, }; -#[cfg(feature = "2025_06_18")] -use crate::schema::{ElicitRequest, ElicitResult}; - use async_trait::async_trait; -use serde_json::Value; - -use crate::mcp_traits::McpClient; -/// Defines the `ClientHandler` trait for handling Model Context Protocol (MCP) operations on a client. -/// This trait provides default implementations for request and notification handlers in an MCP client, -/// allowing developers to override methods for custom behavior. +/// The `ClientHandler` trait defines how a client handles Model Context Protocol (MCP) operations. +/// It includes default implementations for handling requests , notifications and errors and must be +/// extended or overridden by developers to customize client behavior. #[allow(unused)] #[async_trait] pub trait ClientHandler: Send + Sync + 'static { //**********************// //** Request Handlers **// //**********************// + + /// Handles a ping, to check that the other party is still alive. + /// The receiver must promptly respond, or else may be disconnected. async fn handle_ping_request( &self, - request: PingRequest, + params: Option, runtime: &dyn McpClient, ) -> std::result::Result { Ok(Result::default()) } + /// Handles a request from the server to sample an LLM via the client. + /// The client has full discretion over which model to select. + /// The client should also inform the user before beginning sampling, + /// to allow them to inspect the request (human in the loop) and decide whether to approve it. async fn handle_create_message_request( &self, - request: CreateMessageRequest, + params: CreateMessageRequestParams, runtime: &dyn McpClient, ) -> std::result::Result { - runtime.assert_client_request_capabilities(request.method())?; + runtime.assert_client_request_capabilities(CreateMessageRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + CreateMessageRequest::method_value() ))) } + /// Handles a request from the server to request a list of root URIs from the client. Roots allow + /// servers to ask for specific directories or files to operate on. + /// This request is typically used when the server needs to understand the file system + /// structure or access specific locations that the client has permission to read from. async fn handle_list_roots_request( &self, - request: ListRootsRequest, + params: Option, runtime: &dyn McpClient, ) -> std::result::Result { - runtime.assert_client_request_capabilities(request.method())?; + runtime.assert_client_request_capabilities(ListRootsRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + ListRootsRequest::method_value(), ))) } - #[cfg(feature = "2025_06_18")] + ///Handles a request from the server to elicit additional information from the user via the client. async fn handle_elicit_request( &self, - request: ElicitRequest, + params: ElicitRequestParams, runtime: &dyn McpClient, ) -> std::result::Result { - runtime.assert_client_request_capabilities(request.method())?; + runtime.assert_client_request_capabilities(ElicitRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + ElicitRequest::method_value() + ))) + } + + /// Handles a request to retrieve the state of a task. + async fn handle_get_task_request( + &self, + params: GetTaskParams, + runtime: &dyn McpClient, + ) -> std::result::Result { + runtime.assert_client_request_capabilities(GetTaskRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + GetTaskRequest::method_value() ))) } + /// Handles a request to retrieve the result of a completed task. + async fn handle_get_task_payload_request( + &self, + params: GetTaskPayloadParams, + runtime: &dyn McpClient, + ) -> std::result::Result { + runtime.assert_client_request_capabilities(GetTaskPayloadRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + GetTaskPayloadRequest::method_value() + ))) + } + + /// Handles a request to cancel a task. + async fn handle_cancel_task_request( + &self, + params: CancelTaskParams, + runtime: &dyn McpClient, + ) -> std::result::Result { + runtime.assert_client_request_capabilities(CancelTaskRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + CancelTaskRequest::method_value() + ))) + } + + /// Handles a request to retrieve a list of tasks. + async fn handle_list_tasks_request( + &self, + params: Option, + runtime: &dyn McpClient, + ) -> std::result::Result { + runtime.assert_client_request_capabilities(ListTasksRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + ListTasksRequest::method_value() + ))) + } + + /// Handle a custom request async fn handle_custom_request( &self, - request: Value, + request: CustomRequest, runtime: &dyn McpClient, ) -> std::result::Result { - Err(RpcError::method_not_found() - .with_message("No handler is implemented for custom requests.".to_string())) + Err(RpcError::method_not_found().with_message(format!( + "No handler for custom request : \"{}\"", + request.method + ))) } //***************************// //** Notification Handlers **// //***************************// + /// Handles a notification that indicates that it is cancelling a previously-issued request. + /// it is always possible that this notification MAY arrive after the request has already finished. + /// This notification indicates that the result will be unused, so any associated processing SHOULD cease. async fn handle_cancelled_notification( &self, - notification: CancelledNotification, + params: CancelledNotificationParams, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// Handles an out-of-band notification used to inform the receiver of a progress update for a long-running request. async fn handle_progress_notification( &self, - notification: ProgressNotification, + params: ProgressNotificationParams, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// Handles a notification from the server to the client, informing it that the list of resources it can read from has changed. async fn handle_resource_list_changed_notification( &self, - notification: ResourceListChangedNotification, + params: Option, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// handles a notification from the server to the client, informing it that a resource has changed and may need to be read again. async fn handle_resource_updated_notification( &self, - notification: ResourceUpdatedNotification, + params: ResourceUpdatedNotificationParams, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + ///Handles a notification from the server to the client, informing it that the list of prompts it offers has changed. async fn handle_prompt_list_changed_notification( &self, - notification: PromptListChangedNotification, + params: Option, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// Handles a notification from the server to the client, informing it that the list of tools it offers has changed. async fn handle_tool_list_changed_notification( &self, - notification: ToolListChangedNotification, + params: Option, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// Handles notification of a log message passed from server to client. + /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. async fn handle_logging_message_notification( &self, - notification: LoggingMessageNotification, + params: LoggingMessageNotificationParams, + runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Ok(()) + } + + /// Handles a notification from the receiver to the requestor, informing them that a task's status has changed. + /// Receivers are not required to send these notifications. + async fn handle_task_status_notification( + &self, + params: TaskStatusNotificationParams, + runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Ok(()) + } + + /// Handles a notification from the server to the client, informing it of a completion of a out-of-band elicitation request. + async fn handle_elicitation_complete_notification( + &self, + params: ElicitCompleteParams, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) } + /// Handles a custom notification message async fn handle_custom_notification( &self, - notification: Value, + notification: CustomNotification, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index 59444b0..ef9e55e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -1,4 +1,3 @@ -use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 0a51967..a59d5b7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,17 +1,19 @@ use crate::{ mcp_server::server_runtime::ServerRuntimeInternalHandler, - mcp_traits::McpServerHandler, - schema::{schema_utils::CallToolError, *}, + mcp_traits::{McpServerHandler, ToMcpServerHandler}, + schema::{ + schema_utils::{CallToolError, CustomNotification, CustomRequest}, + *, + }, }; use async_trait::async_trait; -use serde_json::Value; use std::sync::Arc; use crate::{mcp_traits::McpServer, utils::enforce_compatible_protocol_version}; -/// Defines the `ServerHandler` trait for handling Model Context Protocol (MCP) operations on a server. -/// This trait provides default implementations for request and notification handlers in an MCP server, -/// allowing developers to override methods for custom behavior. +/// The `ServerHandler` trait defines how a server handles Model Context Protocol (MCP) operations. +/// It provides default implementations for request , notification and error handlers, and must be extended or +/// overridden by developers to customize server behavior. #[allow(unused)] #[async_trait] pub trait ServerHandler: Send + Sync + 'static { @@ -33,20 +35,20 @@ pub trait ServerHandler: Send + Sync + 'static { /// Do not override this unless the standard initialization process doesn't work for you or you need to modify it. async fn handle_initialize_request( &self, - initialize_request: InitializeRequest, + params: InitializeRequestParams, runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. if let Some(updated_protocol_version) = enforce_compatible_protocol_version( - &initialize_request.params.protocol_version, + ¶ms.protocol_version, &server_info.protocol_version, ) .map_err(|err| { tracing::error!( "Incompatible protocol version : client: {} server: {}", - &initialize_request.params.protocol_version, + ¶ms.protocol_version, &server_info.protocol_version ); RpcError::internal_error().with_message(err.to_string()) @@ -55,7 +57,7 @@ pub trait ServerHandler: Send + Sync + 'static { } runtime - .set_client_details(initialize_request.params.clone()) + .set_client_details(params) .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; @@ -69,8 +71,8 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_ping_request( &self, - _: PingRequest, - _: Arc, + _params: Option, + _runtime: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -81,13 +83,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_list_resources_request( &self, - request: ListResourcesRequest, + params: Option, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(ListResourcesRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &ListResourcesRequest::method_value(), ))) } @@ -97,13 +99,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_list_resource_templates_request( &self, - request: ListResourceTemplatesRequest, + params: Option, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(ListResourceTemplatesRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &ListResourceTemplatesRequest::method_value(), ))) } @@ -113,13 +115,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_read_resource_request( &self, - request: ReadResourceRequest, + params: ReadResourceRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(ReadResourceRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &ReadResourceRequest::method_value(), ))) } @@ -129,13 +131,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_subscribe_request( &self, - request: SubscribeRequest, + params: SubscribeRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(SubscribeRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &SubscribeRequest::method_value(), ))) } @@ -145,13 +147,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_unsubscribe_request( &self, - request: UnsubscribeRequest, + params: UnsubscribeRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(UnsubscribeRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &UnsubscribeRequest::method_value(), ))) } @@ -161,13 +163,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_list_prompts_request( &self, - request: ListPromptsRequest, + params: Option, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(ListPromptsRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &ListPromptsRequest::method_value(), ))) } @@ -177,13 +179,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_get_prompt_request( &self, - request: GetPromptRequest, + params: GetPromptRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(GetPromptRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &GetPromptRequest::method_value(), ))) } @@ -193,13 +195,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_list_tools_request( &self, - request: ListToolsRequest, + params: Option, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(ListToolsRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &ListToolsRequest::method_value(), ))) } @@ -209,13 +211,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { runtime - .assert_server_request_capabilities(request.method()) + .assert_server_request_capabilities(CallToolRequest::method_value()) .map_err(CallToolError::new)?; - Ok(CallToolError::unknown_tool(format!("Unknown tool: {}", request.params.name)).into()) + Ok(CallToolError::unknown_tool(format!("Unknown tool: {}", params.name)).into()) } /// Handles requests to enable or adjust logging level. @@ -224,13 +226,13 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_set_level_request( &self, - request: SetLevelRequest, + params: SetLevelRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(SetLevelRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &SetLevelRequest::method_value(), ))) } @@ -240,13 +242,65 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_complete_request( &self, - request: CompleteRequest, + params: CompleteRequestParams, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(CompleteRequest::method_value())?; Err(RpcError::method_not_found().with_message(format!( "No handler is implemented for '{}'.", - request.method(), + &CompleteRequest::method_value(), + ))) + } + + ///Handles a request to retrieve the state of a task. + async fn handle_get_task_request( + &self, + params: GetTaskParams, + runtime: Arc, + ) -> std::result::Result { + runtime.assert_server_request_capabilities(GetTaskRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + &GetTaskRequest::method_value(), + ))) + } + + /// Handles a request to retrieve the result of a completed task. + async fn handle_get_task_payload_request( + &self, + params: GetTaskPayloadParams, + runtime: Arc, + ) -> std::result::Result { + runtime.assert_server_request_capabilities(GetTaskPayloadRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + &GetTaskPayloadRequest::method_value(), + ))) + } + + /// Handles a request to cancel a task. + async fn handle_cancel_task_request( + &self, + params: CancelTaskParams, + runtime: Arc, + ) -> std::result::Result { + runtime.assert_server_request_capabilities(CancelTaskRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + &CancelTaskRequest::method_value(), + ))) + } + + /// Handles a request to retrieve a list of tasks. + async fn handle_list_task_request( + &self, + params: Option, + runtime: Arc, + ) -> std::result::Result { + runtime.assert_server_request_capabilities(ListTasksRequest::method_value())?; + Err(RpcError::method_not_found().with_message(format!( + "No handler is implemented for '{}'.", + &ListTasksRequest::method_value(), ))) } @@ -256,9 +310,9 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_custom_request( &self, - request: Value, + request: CustomRequest, runtime: Arc, - ) -> std::result::Result { + ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) } @@ -269,7 +323,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_initialized_notification( &self, - notification: InitializedNotification, + params: Option, runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) @@ -279,7 +333,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_cancelled_notification( &self, - notification: CancelledNotification, + params: CancelledNotificationParams, runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) @@ -289,7 +343,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_progress_notification( &self, - notification: ProgressNotification, + params: ProgressNotificationParams, runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) @@ -299,7 +353,16 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_roots_list_changed_notification( &self, - notification: RootsListChangedNotification, + params: Option, + runtime: Arc, + ) -> std::result::Result<(), RpcError> { + Ok(()) + } + + ///handles a notification from the receiver to the requestor, informing them that a task's status has changed. + async fn handle_task_status_notification( + &self, + params: TaskStatusNotificationParams, runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) @@ -309,7 +372,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_custom_notification( &self, - notification: Value, + notification: CustomNotification, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -331,11 +394,6 @@ pub trait ServerHandler: Send + Sync + 'static { } } -// Custom trait for conversion -pub trait ToMcpServerHandler { - fn to_mcp_server_handler(self) -> Arc; -} - impl ToMcpServerHandler for T { fn to_mcp_server_handler(self) -> Arc { Arc::new(ServerRuntimeInternalHandler::new(Box::new(self))) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index c89e403..098d081 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,5 +1,6 @@ +use crate::mcp_server::server_runtime_core::RuntimeCoreInternalHandler; use crate::mcp_traits::McpServer; -use crate::schema::schema_utils::*; +use crate::mcp_traits::{McpServerHandler, ToMcpServerHandlerCore}; use crate::schema::*; use async_trait::async_trait; use std::sync::Arc; @@ -49,3 +50,9 @@ pub trait ServerHandlerCore: Send + Sync + 'static { runtime: Arc, ) -> std::result::Result<(), RpcError>; } + +impl ToMcpServerHandlerCore for T { + fn to_mcp_server_handler(self) -> Arc { + Arc::new(RuntimeCoreInternalHandler::new(Box::new(self))) + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs index 52509d3..48608b8 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs @@ -343,7 +343,7 @@ pub(crate) async fn create_standalone_stream( runtime.update_auth_info(auth_info).await; - if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { + if runtime.default_stream_exists().await { let error = SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); return error_response(StatusCode::CONFLICT, error) @@ -549,10 +549,7 @@ pub(crate) async fn process_incoming_message( }; if is_result { - match runtime - .consume_payload_string(DEFAULT_STREAM_ID, payload) - .await - { + match runtime.consume_payload_string(payload).await { Ok(()) => { let body = Full::new(Bytes::new()) .map_err(|err| TransportServerError::HttpError(err.to_string())) diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 75ffcc3..943ded3 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -9,7 +9,6 @@ use crate::auth::AuthInfo; use crate::auth::AuthProvider; use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler}; use crate::mcp_http::{GenericBodyExt, RequestExt}; -use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::SdkError; use crate::{ @@ -263,7 +262,7 @@ impl McpHttpHandler { let message = request.body(); transmit - .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref()) + .consume_payload_string(message.as_ref()) .await .map_err(|err| { tracing::trace!("{}", err); diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs index c8637e0..4e1760c 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs @@ -52,6 +52,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::mcp_icon; use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}; use crate::{ id_generator::{FastIdGenerator, UuidGenerator}, @@ -90,6 +91,14 @@ mod tests { name: "server".to_string(), title: None, version: "0.1.0".to_string(), + description: Some("test Server, by Rust MCP SDK".to_string()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".to_string()), }, }), handler: handler.to_mcp_server_handler(), diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs index f6de197..c8410af 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs @@ -165,6 +165,7 @@ impl Middleware for AuthMiddleware { mod tests { use super::*; use crate::auth::AuthMetadataBuilder; + use crate::mcp_icon; use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}; use crate::{ auth::{OauthTokenVerifier, RemoteAuthProvider}, @@ -290,6 +291,14 @@ mod tests { name: "server".to_string(), title: None, version: "0.1.0".to_string(), + description: Some("Auth Middleware Test Server, by Rust MCP SDK".to_string()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".to_string()), }, }), handler: handler.to_mcp_server_handler(), diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs index 08bbba1..279b55c 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs @@ -367,6 +367,7 @@ mod tests { use crate::{ id_generator::{FastIdGenerator, UuidGenerator}, mcp_http::{types::GenericBodyExt, MiddlewareNext}, + mcp_icon, mcp_server::{ServerHandler, ToMcpServerHandler}, schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}, session_store::InMemorySessionStore, @@ -396,6 +397,14 @@ mod tests { name: "server".to_string(), title: None, version: "0.1.0".to_string(), + description: Some("test server, by Rust MCP SDK".to_string()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".to_string()), }, }), handler: handler.to_mcp_server_handler(), diff --git a/crates/rust-mcp-sdk/src/mcp_macros.rs b/crates/rust-mcp-sdk/src/mcp_macros.rs index d7e7f4a..7213891 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros.rs @@ -1 +1,2 @@ +pub mod mcp_icon; pub mod tool_box; diff --git a/crates/rust-mcp-sdk/src/mcp_macros/mcp_icon.rs b/crates/rust-mcp-sdk/src/mcp_macros/mcp_icon.rs new file mode 100644 index 0000000..87fc7c7 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_macros/mcp_icon.rs @@ -0,0 +1,151 @@ +#[macro_export] +/// Macro to conveniently create an `Icon` instance using a shorthand syntax". +/// +/// # Syntax +/// ```text +/// mcp_icon!( +/// src = "path_or_url", +/// mime_type = "optional_mime_type", +/// sizes = ["WxH", "WxH", ...], +/// theme = "dark" | "light", +/// ) +/// ``` +/// +/// # Rules +/// - `src` is **mandatory**. +/// - `mime_type`, `sizes`, and `theme` are **optional**. +/// - If `theme` is missing or invalid, it defaults to `IconTheme::Light`. +/// - `sizes` uses a Rust array of string literals (DSL style), which are converted to `Vec`. +/// +/// # Example +/// ```rust +/// let my_icon: rust_mcp_sdk::schema::Icon = rust_mcp_sdk::mcp_icon!( +/// src = "/icons/dark.png", +/// mime_type = "image/png", +/// sizes = ["128x128", "256x256"], +/// theme = "dark" +/// ); +/// ``` +macro_rules! mcp_icon { + ( + src = $src:expr + $(, mime_type = $mime_type:expr )? + $(, sizes = [$($size:expr),* $(,)?] )? + $(, theme = $theme:expr )? + $(,)? + ) => { + $crate::schema::Icon { + src: $src.into(), + mime_type: None $(.or(Some($mime_type.into())))?, + sizes: vec![$($($size.into()),*)?], + theme: None $(.or(Some($theme.into())))?, + } + }; +} + +#[cfg(test)] +mod tests { + use crate::schema::*; + + // Helper function to convert IconTheme to &str for easy comparisons + fn theme_str(theme: Option) -> &'static str { + match theme { + Some(IconTheme::Dark) => "dark", + Some(IconTheme::Light) => "light", + None => "none", + } + } + + #[test] + fn test_minimal_icon() { + // Only mandatory src + let icon = mcp_icon!(src = "/icons/simple.png"); + assert_eq!(icon.src, "/icons/simple.png"); + assert!(icon.mime_type.is_none()); + assert!(icon.sizes.is_empty()); + assert!(icon.theme.is_none()); + } + + #[test] + fn test_icon_with_mime_type() { + let icon = mcp_icon!(src = "/icons/simple.png", mime_type = "image/png"); + assert_eq!(icon.src, "/icons/simple.png"); + assert_eq!(icon.mime_type.as_deref(), Some("image/png")); + assert!(icon.sizes.is_empty()); + assert!(icon.theme.is_none()); + } + + #[test] + fn test_icon_with_sizes() { + let icon = mcp_icon!(src = "/icons/simple.png", sizes = ["32x32", "64x64"]); + assert_eq!(icon.src, "/icons/simple.png"); + assert!(icon.mime_type.is_none()); + assert_eq!(icon.sizes, vec!["32x32", "64x64"]); + assert!(icon.theme.is_none()); + } + + #[test] + fn test_icon_with_theme_light() { + let icon = mcp_icon!(src = "/icons/simple.png", theme = "light"); + assert_eq!(icon.src, "/icons/simple.png"); + assert!(icon.mime_type.is_none()); + assert!(icon.sizes.is_empty()); + assert_eq!(theme_str(icon.theme), "light"); + } + + #[test] + fn test_icon_with_theme_dark() { + let icon = mcp_icon!(src = "/icons/simple.png", theme = "dark"); + assert_eq!(theme_str(icon.theme), "dark"); + } + + #[test] + fn test_icon_with_invalid_theme_defaults_to_light() { + let icon = mcp_icon!(src = "/icons/simple.png", theme = "foo"); + // Invalid theme should default to Light + assert_eq!(theme_str(icon.theme), "light"); + } + + #[test] + fn test_icon_full() { + let icon = mcp_icon!( + src = "/icons/full.png", + mime_type = "image/png", + sizes = ["16x16", "32x32", "64x64"], + theme = "dark" + ); + + assert_eq!(icon.src, "/icons/full.png"); + assert_eq!(icon.mime_type.as_deref(), Some("image/png")); + assert_eq!(icon.sizes, vec!["16x16", "32x32", "64x64"]); + assert_eq!(theme_str(icon.theme), "dark"); + } + + #[test] + fn test_icon_sizes_empty_when_missing() { + let icon = mcp_icon!(src = "/icons/empty.png"); + assert!(icon.sizes.is_empty()); + } + + #[test] + fn test_icon_optional_fields_missing() { + let icon = mcp_icon!(src = "/icons/missing.png"); + assert!(icon.mime_type.is_none()); + assert!(icon.sizes.is_empty()); + assert!(icon.theme.is_none()); + } + + #[test] + fn test_icon_trailing_comma() { + let icon = mcp_icon!( + src = "/icons/comma.png", + mime_type = "image/jpeg", + sizes = ["48x48"], + theme = "light", + ); + assert_eq!(icon.src, "/icons/comma.png"); + assert_eq!(icon.mime_type.as_deref(), Some("image/jpeg")); + assert_eq!(icon.sizes, vec!["48x48"]); + assert_eq!(theme_str(icon.theme), "light"); + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 70a18d2..af2d594 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -9,10 +9,9 @@ use crate::{ schema::{ schema_utils::{ self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, - ServerMessage, ServerMessages, + NotificationFromClient, RequestFromClient, ServerMessage, ServerMessages, }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RequestId, RpcError, ServerResult, + InitializeRequestParams, InitializeResult, RequestId, RpcError, ServerResult, }, }; use async_trait::async_trait; @@ -21,8 +20,8 @@ use futures::StreamExt; #[cfg(feature = "streamable-http")] use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions}; -use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher}; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use rust_mcp_transport::{IoStream, SessionId, StreamId, TransportDispatcher}; +use std::{sync::Arc, time::Duration}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::{watch, Mutex}; @@ -40,7 +39,7 @@ type TransportType = Arc; pub struct ClientRuntime { // A thread-safe map storing transport types - transport_map: tokio::sync::RwLock>, + transport_map: tokio::sync::RwLock>, // The handler for processing MCP messages handler: Box, // Information about the server @@ -68,12 +67,10 @@ impl ClientRuntime { transport: TransportType, handler: Box, ) -> Self { - let mut map: HashMap = HashMap::new(); - map.insert(DEFAULT_STREAM_ID.to_string(), transport); let (server_details_tx, server_details_rx) = watch::channel::>(None); Self { - transport_map: tokio::sync::RwLock::new(map), + transport_map: tokio::sync::RwLock::new(Some(transport)), handler, client_details, handlers: Mutex::new(vec![]), @@ -94,11 +91,10 @@ impl ClientRuntime { transport_options: StreamableTransportOptions, handler: Box, ) -> Self { - let map: HashMap = HashMap::new(); let (server_details_tx, server_details_rx) = watch::channel::>(None); Self { - transport_map: tokio::sync::RwLock::new(map), + transport_map: tokio::sync::RwLock::new(None), handler, client_details, handlers: Mutex::new(vec![]), @@ -113,8 +109,12 @@ impl ClientRuntime { } async fn initialize_request(self: Arc) -> SdkResult<()> { - let request = InitializeRequest::new(self.client_details.clone()); - let result: ServerResult = self.request(request.into(), None).await?.try_into()?; + let result: ServerResult = self + .request( + RequestFromClient::InitializeRequest(self.client_details.clone()), + None, + ) + .await?; if let ServerResult::InitializeResult(initialize_result) = result { ensure_server_protocole_compatibility( @@ -131,7 +131,7 @@ impl ClientRuntime { } // send a InitializedNotification to the server - self.send_notification(InitializedNotification::new(None).into()) + self.send_notification(NotificationFromClient::InitializedNotification(None)) .await?; } else { return Err(RpcError::invalid_params() @@ -149,9 +149,10 @@ impl ClientRuntime { ) -> SdkResult> { let response = match message { ServerMessage::Request(jsonrpc_request) => { + let request_id = jsonrpc_request.request_id().clone(); let result = self .handler - .handle_request(jsonrpc_request.request, self) + .handle_request(jsonrpc_request.into(), self) .await; // create a response to send back to the server @@ -160,12 +161,12 @@ impl ClientRuntime { Err(error_value) => MessageFromClient::Error(error_value), }; - let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?; + let mcp_message = ClientMessage::from_message(response, Some(request_id))?; Some(mcp_message) } ServerMessage::Notification(jsonrpc_notification) => { self.handler - .handle_notification(jsonrpc_notification.notification, self) + .handle_notification(jsonrpc_notification.into(), self) .await?; None } @@ -173,15 +174,17 @@ impl ClientRuntime { self.handler .handle_error(&jsonrpc_error.error, self) .await?; - if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { - tx_response - .send(ServerMessage::Error(jsonrpc_error)) - .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; - } else { - tracing::warn!( - "Received an error response with no corresponding request: {:?}", - &jsonrpc_error.id - ); + if let Some(request_id) = jsonrpc_error.id.as_ref() { + if let Some(tx_response) = transport.pending_request_tx(request_id).await { + tx_response + .send(ServerMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request: {:?}", + &request_id + ); + } } None } @@ -205,7 +208,7 @@ impl ClientRuntime { async fn start_standalone(self: Arc) -> SdkResult<()> { let self_clone = self.clone(); let transport_map = self_clone.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -319,19 +322,10 @@ impl ClientRuntime { ) -> SdkResult<()> { let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); - transport_map.insert(stream_id.to_string(), transport); + *transport_map = Some(transport); Ok(()) } - pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult { - let transport_map = self.transport_map.read().await; - transport_map.get(stream_id).cloned().ok_or_else(|| { - RpcError::internal_error() - .with_message(format!("Transport for key {stream_id} not found")) - .into() - }) - } - #[cfg(feature = "streamable-http")] pub(crate) async fn new_transport( &self, @@ -440,14 +434,18 @@ impl ClientRuntime { } }; - let transport = Arc::new(self.new_transport(session_id, false).await?); + let transport: Arc< + dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > = Arc::new(self.new_transport(session_id, false).await?); let mut stream = transport.start().await?; - self.store_transport(&stream_id, transport).await?; - - let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove - let send_task = async { let result = transport.send_message(messages, timeout).await?; @@ -553,7 +551,7 @@ impl McpClient for ClientRuntime { let transport_map = self.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -592,7 +590,7 @@ impl McpClient for ClientRuntime { } let transport_map = self.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -642,9 +640,9 @@ impl McpClient for ClientRuntime { *is_shut_down_lock = true; let mut transport_map = self.transport_map.write().await; - let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + let transport_option = transport_map.take(); drop(transport_map); - for transport in transports { + if let Some(transport) = transport_option { let _ = transport.shut_down().await; } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 06964ed..576d7a7 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -1,21 +1,17 @@ -use std::sync::Arc; - +use super::ClientRuntime; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, RequestFromServer, ResultFromClient, ServerMessage, ServerMessages, }, - InitializeRequestParams, RpcError, ServerNotification, ServerRequest, + InitializeRequestParams, RpcError, }; +use crate::{error::SdkResult, mcp_client::ClientHandler, mcp_traits::McpClientHandler, McpClient}; use async_trait::async_trait; - #[cfg(feature = "streamable-http")] use rust_mcp_transport::StreamableTransportOptions; use rust_mcp_transport::TransportDispatcher; - -use crate::{error::SdkResult, mcp_client::ClientHandler, mcp_traits::McpClientHandler, McpClient}; - -use super::ClientRuntime; +use std::sync::Arc; /// Creates a new MCP client runtime with the specified configuration. /// @@ -91,33 +87,51 @@ impl McpClientHandler for ClientInternalHandler> { runtime: &dyn McpClient, ) -> std::result::Result { match server_jsonrpc_request { - RequestFromServer::ServerRequest(request) => match request { - ServerRequest::PingRequest(ping_request) => self - .handler - .handle_ping_request(ping_request, runtime) - .await - .map(|value| value.into()), - ServerRequest::CreateMessageRequest(create_message_request) => self - .handler - .handle_create_message_request(create_message_request, runtime) - .await - .map(|value| value.into()), - ServerRequest::ListRootsRequest(list_roots_request) => self - .handler - .handle_list_roots_request(list_roots_request, runtime) - .await - .map(|value| value.into()), - #[cfg(feature = "2025_06_18")] - ServerRequest::ElicitRequest(elicit_request) => self - .handler - .handle_elicit_request(elicit_request, runtime) - .await - .map(|value| value.into()), - }, - // Handles custom notifications received from the server by passing the request to self.handler - RequestFromServer::CustomRequest(custom_request) => self + RequestFromServer::PingRequest(params) => self + .handler + .handle_ping_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::CreateMessageRequest(params) => self + .handler + .handle_create_message_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::ListRootsRequest(params) => self + .handler + .handle_list_roots_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::ElicitRequest(params) => self + .handler + .handle_elicit_request(params, runtime) + .await + .map(|value| value.into()), + + RequestFromServer::GetTaskRequest(params) => self + .handler + .handle_get_task_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::GetTaskPayloadRequest(params) => self + .handler + .handle_get_task_payload_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::CancelTaskRequest(params) => self + .handler + .handle_cancel_task_request(params, runtime) + .await + .map(|value| value.into()), + RequestFromServer::ListTasksRequest(params) => self + .handler + .handle_list_tasks_request(params, runtime) + .await + .map(|value| value.into()), + + RequestFromServer::CustomRequest(params) => self .handler - .handle_custom_request(custom_request, runtime) + .handle_custom_request(params, runtime) .await .map(|value| value.into()), } @@ -140,70 +154,67 @@ impl McpClientHandler for ClientInternalHandler> { runtime: &dyn McpClient, ) -> SdkResult<()> { match server_jsonrpc_notification { - NotificationFromServer::ServerNotification(server_notification) => { - match server_notification { - ServerNotification::CancelledNotification(cancelled_notification) => { - self.handler - .handle_cancelled_notification(cancelled_notification, runtime) - .await?; - } - ServerNotification::ProgressNotification(progress_notification) => { - self.handler - .handle_progress_notification(progress_notification, runtime) - .await?; - } - ServerNotification::ResourceListChangedNotification( + NotificationFromServer::CancelledNotification(cancelled_notification) => { + self.handler + .handle_cancelled_notification(cancelled_notification, runtime) + .await?; + } + NotificationFromServer::ProgressNotification(progress_notification) => { + self.handler + .handle_progress_notification(progress_notification, runtime) + .await?; + } + NotificationFromServer::ResourceListChangedNotification( + resource_list_changed_notification, + ) => { + self.handler + .handle_resource_list_changed_notification( resource_list_changed_notification, - ) => { - self.handler - .handle_resource_list_changed_notification( - resource_list_changed_notification, - runtime, - ) - .await?; - } - ServerNotification::ResourceUpdatedNotification( - resource_updated_notification, - ) => { - self.handler - .handle_resource_updated_notification( - resource_updated_notification, - runtime, - ) - .await?; - } - ServerNotification::PromptListChangedNotification( + runtime, + ) + .await?; + } + NotificationFromServer::ResourceUpdatedNotification(resource_updated_notification) => { + self.handler + .handle_resource_updated_notification(resource_updated_notification, runtime) + .await?; + } + NotificationFromServer::PromptListChangedNotification( + prompt_list_changed_notification, + ) => { + self.handler + .handle_prompt_list_changed_notification( prompt_list_changed_notification, - ) => { - self.handler - .handle_prompt_list_changed_notification( - prompt_list_changed_notification, - runtime, - ) - .await?; - } - ServerNotification::ToolListChangedNotification( - tool_list_changed_notification, - ) => { - self.handler - .handle_tool_list_changed_notification( - tool_list_changed_notification, - runtime, - ) - .await?; - } - ServerNotification::LoggingMessageNotification( - logging_message_notification, - ) => { - self.handler - .handle_logging_message_notification( - logging_message_notification, - runtime, - ) - .await?; - } - } + runtime, + ) + .await?; + } + NotificationFromServer::ToolListChangedNotification(tool_list_changed_notification) => { + self.handler + .handle_tool_list_changed_notification(tool_list_changed_notification, runtime) + .await?; + } + NotificationFromServer::LoggingMessageNotification(logging_message_notification) => { + self.handler + .handle_logging_message_notification(logging_message_notification, runtime) + .await?; + } + NotificationFromServer::TaskStatusNotification(task_status_notification) => { + self.handler + .handle_task_status_notification(task_status_notification, runtime) + .await?; } + NotificationFromServer::ElicitationCompleteNotification( + elicitation_complete_notification, + ) => { + self.handler + .handle_elicitation_complete_notification( + elicitation_complete_notification, + runtime, + ) + .await?; + } + // Handles custom notifications received from the server by passing the request to self.handler NotificationFromServer::CustomNotification(custom_notification) => { self.handler diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index a429bae..98b6fe1 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -17,7 +17,6 @@ use futures::{StreamExt, TryFutureExt}; #[cfg(feature = "hyper-server")] use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; -use std::collections::HashMap; use std::panic; use std::sync::Arc; use std::time::Duration; @@ -46,7 +45,7 @@ pub struct ServerRuntime { server_details: Arc, #[cfg(feature = "hyper-server")] session_id: Option, - transport_map: tokio::sync::RwLock>, //TODO: remove the transport_map, we do not need a hashmap for it + transport_map: tokio::sync::RwLock>, request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, @@ -107,7 +106,7 @@ impl McpServer for ServerRuntime { request_timeout: Option, ) -> SdkResult> { let transport_map = self.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -133,7 +132,7 @@ impl McpServer for ServerRuntime { request_timeout: Option, ) -> SdkResult>> { let transport_map = self.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -160,7 +159,7 @@ impl McpServer for ServerRuntime { let self_clone = self.clone(); let transport_map = self_clone.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -254,7 +253,7 @@ impl McpServer for ServerRuntime { async fn stderr_message(&self, message: String) -> SdkResult<()> { let transport_map = self.transport_map.read().await; - let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; @@ -275,14 +274,10 @@ impl McpServer for ServerRuntime { } impl ServerRuntime { - pub(crate) async fn consume_payload_string( - &self, - stream_id: &str, - payload: &str, - ) -> SdkResult<()> { + pub(crate) async fn consume_payload_string(&self, payload: &str) -> SdkResult<()> { let transport_map = self.transport_map.read().await; - let transport = transport_map.get(stream_id).ok_or( + let transport = transport_map.as_ref().ok_or( RpcError::internal_error() .with_message("stream id does not exists or is closed!".to_string()), )?; @@ -308,10 +303,14 @@ impl ServerRuntime { let response = match message { // Handle a client request ClientMessage::Request(client_jsonrpc_request) => { + let request_id = client_jsonrpc_request.request_id().clone(); + let result = self .handler - .handle_request(client_jsonrpc_request.request, self.clone()) + .handle_request(client_jsonrpc_request.into(), self.clone()) .await; + println!(">>> {:?} ", result); + // create a response to send back to the client let response: MessageFromServer = match result { Ok(success_value) => success_value.into(), @@ -326,13 +325,13 @@ impl ServerRuntime { }; let mpc_message: ServerMessage = - ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?; + ServerMessage::from_message(response, Some(request_id))?; Some(mpc_message) } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self.clone()) + .handle_notification(client_jsonrpc_notification.into(), self.clone()) .await?; None } @@ -340,15 +339,18 @@ impl ServerRuntime { self.handler .handle_error(&jsonrpc_error.error, self.clone()) .await?; - if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { - tx_response - .send(ClientMessage::Error(jsonrpc_error)) - .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; - } else { - tracing::warn!( - "Received an error response with no corresponding request {:?}", - &jsonrpc_error.id - ); + + if let Some(request_id) = jsonrpc_error.id.as_ref() { + if let Some(tx_response) = transport.pending_request_tx(request_id).await { + tx_response + .send(ClientMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request {:?}", + &jsonrpc_error.id + ); + } } None } @@ -387,7 +389,7 @@ impl ServerRuntime { } let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); - transport_map.insert(stream_id.to_string(), transport); + *transport_map = Some(transport); Ok(()) } @@ -398,7 +400,7 @@ impl ServerRuntime { } let transport_map = self.transport_map.read().await; tracing::trace!("removing transport for stream id : {}", stream_id); - if let Some(transport) = transport_map.get(stream_id) { + if let Some(transport) = transport_map.as_ref() { transport.shut_down().await?; } // transport_map.remove(stream_id); @@ -407,16 +409,16 @@ impl ServerRuntime { pub(crate) async fn shutdown(&self) { let mut transport_map = self.transport_map.write().await; - let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + let transport_option = transport_map.take(); drop(transport_map); - for item in items { - let _ = item.shut_down().await; + if let Some(transport) = transport_option { + let _ = transport.shut_down().await; } } - pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { + pub(crate) async fn default_stream_exists(&self) -> bool { let transport_map = self.transport_map.read().await; - let live_transport = if let Some(t) = transport_map.get(stream_id) { + let live_transport = if let Some(t) = transport_map.as_ref() { !t.is_shut_down().await } else { false @@ -581,7 +583,7 @@ impl ServerRuntime { server_details, handler, session_id: Some(session_id), - transport_map: tokio::sync::RwLock::new(HashMap::new()), + transport_map: tokio::sync::RwLock::new(None), client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), @@ -600,8 +602,6 @@ impl ServerRuntime { >, handler: Arc, ) -> Arc { - let mut map: HashMap = HashMap::new(); - map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); let (client_details_tx, client_details_rx) = watch::channel::>(None); Arc::new(Self { @@ -609,7 +609,7 @@ impl ServerRuntime { handler, #[cfg(feature = "hyper-server")] session_id: None, - transport_map: tokio::sync::RwLock::new(map), + transport_map: tokio::sync::RwLock::new(Some(Arc::new(transport))), client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index c4eeb81..253ecd6 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use super::ServerRuntime; #[cfg(feature = "hyper-server")] use crate::auth::AuthInfo; use crate::schema::{ @@ -7,21 +6,18 @@ use crate::schema::{ self, CallToolError, ClientMessage, ClientMessages, MessageFromServer, NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, }, - CallToolResult, ClientNotification, ClientRequest, InitializeResult, RpcError, + CallToolResult, InitializeResult, RpcError, }; -use async_trait::async_trait; - -use rust_mcp_transport::TransportDispatcher; - -use super::ServerRuntime; -#[cfg(feature = "hyper-server")] -use rust_mcp_transport::SessionId; - use crate::{ error::SdkResult, mcp_handlers::mcp_server_handler::ServerHandler, mcp_traits::{McpServer, McpServerHandler}, }; +use async_trait::async_trait; +#[cfg(feature = "hyper-server")] +use rust_mcp_transport::SessionId; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP server runtime with the specified configuration. /// @@ -50,13 +46,9 @@ pub fn create_server( ServerMessages, ServerMessage, >, - handler: impl ServerHandler, + handler: Arc, ) -> Arc { - ServerRuntime::new( - server_details, - transport, - Arc::new(ServerRuntimeInternalHandler::new(Box::new(handler))), - ) + ServerRuntime::new(server_details, transport, handler) } #[cfg(feature = "hyper-server")] @@ -86,93 +78,108 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { - schema_utils::RequestFromClient::ClientRequest(client_request) => { - match client_request { - ClientRequest::InitializeRequest(initialize_request) => self - .handler - .handle_initialize_request(initialize_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::PingRequest(ping_request) => self - .handler - .handle_ping_request(ping_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::ListResourcesRequest(list_resources_request) => self - .handler - .handle_list_resources_request(list_resources_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::ListResourceTemplatesRequest( + RequestFromClient::InitializeRequest(initialize_request) => self + .handler + .handle_initialize_request(initialize_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::PingRequest(ping_request) => self + .handler + .handle_ping_request(ping_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::ListResourcesRequest(list_resources_request) => self + .handler + .handle_list_resources_request(list_resources_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::ListResourceTemplatesRequest(list_resource_templates_request) => { + self.handler + .handle_list_resource_templates_request( list_resource_templates_request, - ) => self - .handler - .handle_list_resource_templates_request( - list_resource_templates_request, - runtime, - ) - .await - .map(|value| value.into()), - ClientRequest::ReadResourceRequest(read_resource_request) => self - .handler - .handle_read_resource_request(read_resource_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::SubscribeRequest(subscribe_request) => self - .handler - .handle_subscribe_request(subscribe_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::UnsubscribeRequest(unsubscribe_request) => self - .handler - .handle_unsubscribe_request(unsubscribe_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::ListPromptsRequest(list_prompts_request) => self - .handler - .handle_list_prompts_request(list_prompts_request, runtime) - .await - .map(|value| value.into()), + runtime, + ) + .await + .map(|value| value.into()) + } + RequestFromClient::ReadResourceRequest(read_resource_request) => self + .handler + .handle_read_resource_request(read_resource_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::SubscribeRequest(subscribe_request) => self + .handler + .handle_subscribe_request(subscribe_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::UnsubscribeRequest(unsubscribe_request) => self + .handler + .handle_unsubscribe_request(unsubscribe_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::ListPromptsRequest(list_prompts_request) => self + .handler + .handle_list_prompts_request(list_prompts_request, runtime) + .await + .map(|value| value.into()), - ClientRequest::GetPromptRequest(prompt_request) => self - .handler - .handle_get_prompt_request(prompt_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::ListToolsRequest(list_tools_request) => self - .handler - .handle_list_tools_request(list_tools_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::CallToolRequest(call_tool_request) => { - let result = self - .handler - .handle_call_tool_request(call_tool_request, runtime) - .await; + RequestFromClient::GetPromptRequest(prompt_request) => self + .handler + .handle_get_prompt_request(prompt_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::ListToolsRequest(list_tools_request) => self + .handler + .handle_list_tools_request(list_tools_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::CallToolRequest(call_tool_request) => { + let result = self + .handler + .handle_call_tool_request(call_tool_request, runtime) + .await; - Ok(result.map_or_else( - |err| { - let result: CallToolResult = CallToolError::new(err).into(); - result.into() - }, - Into::into, - )) - } - ClientRequest::SetLevelRequest(set_level_request) => self - .handler - .handle_set_level_request(set_level_request, runtime) - .await - .map(|value| value.into()), - ClientRequest::CompleteRequest(complete_request) => self - .handler - .handle_complete_request(complete_request, runtime) - .await - .map(|value| value.into()), - } + Ok(result.map_or_else( + |err| { + let result: CallToolResult = CallToolError::new(err).into(); + result.into() + }, + Into::into, + )) } - schema_utils::RequestFromClient::CustomRequest(value) => self + RequestFromClient::SetLevelRequest(set_level_request) => self + .handler + .handle_set_level_request(set_level_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::CompleteRequest(complete_request) => self + .handler + .handle_complete_request(complete_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::GetTaskRequest(get_task_request) => self .handler - .handle_custom_request(value, runtime) + .handle_get_task_request(get_task_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::GetTaskPayloadRequest(get_task_payload_request) => self + .handler + .handle_get_task_payload_request(get_task_payload_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::CancelTaskRequest(cancel_task_request) => self + .handler + .handle_cancel_task_request(cancel_task_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::ListTasksRequest(list_tasks_request) => self + .handler + .handle_list_task_request(list_tasks_request, runtime) + .await + .map(|value| value.into()), + RequestFromClient::CustomRequest(custom_request) => self + .handler + .handle_custom_request(custom_request, runtime) .await .map(|value| value.into()), } @@ -193,39 +200,38 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { - schema_utils::NotificationFromClient::ClientNotification(client_notification) => { - match client_notification { - ClientNotification::CancelledNotification(cancelled_notification) => { - self.handler - .handle_cancelled_notification(cancelled_notification, runtime) - .await?; - } - ClientNotification::InitializedNotification(initialized_notification) => { - self.handler - .handle_initialized_notification( - initialized_notification, - runtime.clone(), - ) - .await?; - self.handler.on_initialized(runtime).await; - } - ClientNotification::ProgressNotification(progress_notification) => { - self.handler - .handle_progress_notification(progress_notification, runtime) - .await?; - } - ClientNotification::RootsListChangedNotification( + NotificationFromClient::CancelledNotification(cancelled_notification) => { + self.handler + .handle_cancelled_notification(cancelled_notification, runtime) + .await?; + } + NotificationFromClient::InitializedNotification(initialized_notification) => { + self.handler + .handle_initialized_notification(initialized_notification, runtime.clone()) + .await?; + self.handler.on_initialized(runtime).await; + } + NotificationFromClient::ProgressNotification(progress_notification) => { + self.handler + .handle_progress_notification(progress_notification, runtime) + .await?; + } + NotificationFromClient::RootsListChangedNotification( + roots_list_changed_notification, + ) => { + self.handler + .handle_roots_list_changed_notification( roots_list_changed_notification, - ) => { - self.handler - .handle_roots_list_changed_notification( - roots_list_changed_notification, - runtime, - ) - .await?; - } - } + runtime, + ) + .await?; } + NotificationFromClient::TaskStatusNotification(task_status_notification) => { + self.handler + .handle_task_status_notification(task_status_notification, runtime) + .await?; + } + schema_utils::NotificationFromClient::CustomNotification(value) => { self.handler.handle_custom_notification(value).await?; } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index c617cea..f0ecd4a 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -3,12 +3,12 @@ use crate::error::SdkResult; use crate::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; use crate::mcp_traits::{McpServer, McpServerHandler}; use crate::schema::schema_utils::{ - self, ClientMessage, MessageFromServer, NotificationFromClient, RequestFromClient, - ResultFromServer, ServerMessage, + ClientMessage, MessageFromServer, NotificationFromClient, RequestFromClient, ResultFromServer, + ServerMessage, }; use crate::schema::{ schema_utils::{ClientMessages, ServerMessages}, - ClientRequest, InitializeResult, RpcError, + InitializeResult, RpcError, }; use async_trait::async_trait; use rust_mcp_transport::TransportDispatcher; @@ -41,13 +41,9 @@ pub fn create_server( ServerMessages, ServerMessage, >, - handler: impl ServerHandlerCore, + handler: Arc, ) -> Arc { - ServerRuntime::new( - server_details, - transport, - Arc::new(RuntimeCoreInternalHandler::new(Box::new(handler))), - ) + ServerRuntime::new(server_details, transport, handler) } pub(crate) struct RuntimeCoreInternalHandler { @@ -68,13 +64,10 @@ impl McpServerHandler for RuntimeCoreInternalHandler> runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request - if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( - initialize_request, - )) = &client_jsonrpc_request - { + if let RequestFromClient::InitializeRequest(initialize_request) = &client_jsonrpc_request { // keep a copy of the InitializeRequestParams which includes client_info and capabilities runtime - .set_client_details(initialize_request.params.clone()) + .set_client_details(initialize_request.clone()) .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index c295082..e412acb 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -3,18 +3,21 @@ use crate::schema::{ ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, }, - CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, - CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, - InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, - ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, - ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, - RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, - SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, - UnsubscribeRequest, UnsubscribeRequestParams, + CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequestParams, + CreateMessageRequest, GenericResult, GetPromptRequest, GetPromptRequestParams, Implementation, + InitializeRequestParams, InitializeResult, ListPromptsRequest, ListResourceTemplatesRequest, + ListResourcesRequest, ListRootsRequest, ListToolsRequest, NotificationParams, + PaginatedRequestParams, ReadResourceRequest, ReadResourceRequestParams, RequestId, + RequestParams, RootsListChangedNotification, RpcError, ServerCapabilities, SetLevelRequest, + SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, + UnsubscribeRequestParams, }; use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; +use rust_mcp_schema::{ + schema_utils::CustomNotification, CancelledNotificationParams, ProgressNotificationParams, + TaskStatusNotificationParams, +}; use std::{sync::Arc, time::Duration}; #[async_trait] @@ -151,6 +154,123 @@ pub trait McpClient: Sync + Send { self.server_info()?.instructions } + /// Asserts that server capabilities support the requested method. + /// + /// Verifies that the server has the necessary capabilities to handle the given request method. + /// If the server is not initialized or lacks a required capability, an error is returned. + /// This can be utilized to avoid sending requests when the opposing party lacks support for them. + fn assert_server_capabilities(&self, request_method: &str) -> SdkResult<()> { + let entity = "Server"; + + let capabilities = self.server_capabilities().ok_or::( + RpcError::internal_error().with_message("Server is not initialized!".to_string()), + )?; + + if request_method == SetLevelRequest::method_value() && capabilities.logging.is_none() { + return Err(RpcError::internal_error() + .with_message(format_assertion_message(entity, "logging", request_method)) + .into()); + } + + if [ + GetPromptRequest::method_value(), + ListPromptsRequest::method_value(), + ] + .contains(&request_method) + && capabilities.prompts.is_none() + { + return Err(RpcError::internal_error() + .with_message(format_assertion_message(entity, "prompts", request_method)) + .into()); + } + + if [ + ListResourcesRequest::method_value(), + ListResourceTemplatesRequest::method_value(), + ReadResourceRequest::method_value(), + SubscribeRequest::method_value(), + UnsubscribeRequest::method_value(), + ] + .contains(&request_method) + && capabilities.resources.is_none() + { + return Err(RpcError::internal_error() + .with_message(format_assertion_message( + entity, + "resources", + request_method, + )) + .into()); + } + + if [ + CallToolRequest::method_value(), + ListToolsRequest::method_value(), + ] + .contains(&request_method) + && capabilities.tools.is_none() + { + return Err(RpcError::internal_error() + .with_message(format_assertion_message(entity, "tools", request_method)) + .into()); + } + + Ok(()) + } + + fn assert_client_notification_capabilities( + &self, + notification_method: &str, + ) -> std::result::Result<(), RpcError> { + let entity = "Client"; + let capabilities = &self.client_info().capabilities; + + if notification_method == RootsListChangedNotification::method_value() + && capabilities.roots.is_some() + { + return Err( + RpcError::internal_error().with_message(format_assertion_message( + entity, + "roots list changed notifications", + notification_method, + )), + ); + } + + Ok(()) + } + + fn assert_client_request_capabilities( + &self, + request_method: &str, + ) -> std::result::Result<(), RpcError> { + let entity = "Client"; + let capabilities = &self.client_info().capabilities; + + if request_method == CreateMessageRequest::method_value() && capabilities.sampling.is_some() + { + return Err( + RpcError::internal_error().with_message(format_assertion_message( + entity, + "sampling capability", + request_method, + )), + ); + } + + if request_method == ListRootsRequest::method_value() && capabilities.roots.is_some() { + return Err( + RpcError::internal_error().with_message(format_assertion_message( + entity, + "roots capability", + request_method, + )), + ); + } + + Ok(()) + } + /// Sends a request to the server and processes the response. /// /// This function sends a `RequestFromClient` message to the server, waits for the response, @@ -198,6 +318,10 @@ pub trait McpClient: Sync + Send { Ok(()) } + /******************* + Requests + *******************/ + /// A ping request to check that the other party is still alive. /// The receiver must promptly respond, or else may be disconnected. /// @@ -208,232 +332,315 @@ pub trait McpClient: Sync + Send { /// # Returns /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. /// If the request or conversion fails, an error is returned. - async fn ping(&self, timeout: Option) -> SdkResult { - let ping_request = PingRequest::new(None); - let response = self.request(ping_request.into(), timeout).await?; + async fn ping( + &self, + params: Option, + timeout: Option, + ) -> SdkResult { + let response = self + .request(RequestFromClient::PingRequest(params), timeout) + .await?; Ok(response.try_into()?) } - async fn complete( + ///send a request from the client to the server, to ask for completion options. + async fn request_completion( &self, params: CompleteRequestParams, ) -> SdkResult { - let request = CompleteRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::CompleteRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult { - let request = SetLevelRequest::new(SetLevelRequestParams { level }); - let response = self.request(request.into(), None).await?; + /// send a request from the client to the server, to enable or adjust logging. + async fn request_set_logging_level( + &self, + params: SetLevelRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::SetLevelRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn get_prompt( + /// send a request to get a prompt provided by the server. + async fn request_prompt( &self, params: GetPromptRequestParams, ) -> SdkResult { - let request = GetPromptRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::GetPromptRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn list_prompts( + ///Request a list of prompts and prompt templates the server has. + async fn request_prompt_list( &self, - params: Option, + params: Option, ) -> SdkResult { - let request = ListPromptsRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::ListPromptsRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn list_resources( + /// request a list of resources the server has. + async fn request_resource_list( &self, - params: Option, + params: Option, ) -> SdkResult { - // passing ListResourcesRequestParams::default() if params is None - // need to investigate more but this could be a inconsistency on some MCP servers - // where it is not required for other requests like prompts/list or tools/list etc - // that excepts an empty params to be passed (like server-everything) - let request = - ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default()))); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::ListResourcesRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn list_resource_templates( + /// request a list of resource templates the server has. + async fn request_resource_template_list( &self, - params: Option, + params: Option, ) -> SdkResult { - let request = ListResourceTemplatesRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request( + RequestFromClient::ListResourceTemplatesRequest(params), + None, + ) + .await?; Ok(response.try_into()?) } - async fn read_resource( + /// send a request to the server to to read a specific resource URI. + async fn request_resource_read( &self, params: ReadResourceRequestParams, ) -> SdkResult { - let request = ReadResourceRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::ReadResourceRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn subscribe_resource( + /// request resources/updated notifications from the server whenever a particular resource changes. + async fn request_resource_subscription( &self, params: SubscribeRequestParams, ) -> SdkResult { - let request = SubscribeRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::SubscribeRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn unsubscribe_resource( + /// request cancellation of resources/updated notifications from the server. + /// This should follow a previous resources/subscribe request. + async fn request_resource_unsubscription( &self, params: UnsubscribeRequestParams, ) -> SdkResult { - let request = UnsubscribeRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::UnsubscribeRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult { - let request = CallToolRequest::new(params); - let response = self.request(request.into(), None).await?; + /// invoke a tool provided by the server. + async fn request_tool_call(&self, params: CallToolRequestParams) -> SdkResult { + let response = self + .request(RequestFromClient::CallToolRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn list_tools( + /// request a list of tools the server has. + async fn request_tool_list( &self, - params: Option, + params: Option, ) -> SdkResult { - let request = ListToolsRequest::new(params); - let response = self.request(request.into(), None).await?; + let response = self + .request(RequestFromClient::ListToolsRequest(params), None) + .await?; Ok(response.try_into()?) } - async fn send_roots_list_changed( - &self, - params: Option, - ) -> SdkResult<()> { - let notification = RootsListChangedNotification::new(params); - self.send_notification(notification.into()).await + /******************* + Notifications + *******************/ + + /// A notification from the client to the server, informing it that the list of roots has changed. + /// This notification should be sent whenever the client adds, removes, or modifies any root. + async fn notify_roots_list_changed(&self, params: Option) -> SdkResult<()> { + self.send_notification(NotificationFromClient::RootsListChangedNotification(params)) + .await } - /// Asserts that server capabilities support the requested method. - /// - /// Verifies that the server has the necessary capabilities to handle the given request method. - /// If the server is not initialized or lacks a required capability, an error is returned. - /// This can be utilized to avoid sending requests when the opposing party lacks support for them. - fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> { - let entity = "Server"; + /// This notification can be sent by either side to indicate that it is cancelling a previously-issued request. + /// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + /// This notification indicates that the result will be unused, so any associated processing SHOULD cease. + /// A client MUST NOT attempt to cancel its initialize request. + /// For task cancellation, use the tasks/cancel request instead of this notification + async fn notify_cancellation(&self, params: CancelledNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromClient::CancelledNotification(params)) + .await + } - let capabilities = self.server_capabilities().ok_or::( - RpcError::internal_error().with_message("Server is not initialized!".to_string()), - )?; + ///Send an out-of-band notification used to inform the receiver of a progress update for a long-running request. + async fn notify_progress(&self, params: ProgressNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromClient::ProgressNotification(params)) + .await + } - if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() { - return Err(RpcError::internal_error() - .with_message(format_assertion_message(entity, "logging", request_method)) - .into()); - } + /// Send an optional notification from the receiver to the requestor, informing them that a task's status has changed. + /// Receivers are not required to send these notifications. + async fn notify_task_status(&self, params: TaskStatusNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromClient::TaskStatusNotification(params)) + .await + } - if [ - GetPromptRequest::method_name(), - ListPromptsRequest::method_name(), - ] - .contains(request_method) - && capabilities.prompts.is_none() - { - return Err(RpcError::internal_error() - .with_message(format_assertion_message(entity, "prompts", request_method)) - .into()); - } + ///Send a custom notification + async fn notify_custom(&self, params: CustomNotification) -> SdkResult<()> { + self.send_notification(NotificationFromClient::CustomNotification(params)) + .await + } - if [ - ListResourcesRequest::method_name(), - ListResourceTemplatesRequest::method_name(), - ReadResourceRequest::method_name(), - SubscribeRequest::method_name(), - UnsubscribeRequest::method_name(), - ] - .contains(request_method) - && capabilities.resources.is_none() - { - return Err(RpcError::internal_error() - .with_message(format_assertion_message( - entity, - "resources", - request_method, - )) - .into()); - } + /******************* + Deprecated + *******************/ + #[deprecated(since = "0.8.0", note = "Use `request_completion()` instead.")] + async fn complete( + &self, + params: CompleteRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::CompleteRequest(params), None) + .await?; + Ok(response.try_into()?) + } - if [ - CallToolRequest::method_name(), - ListToolsRequest::method_name(), - ] - .contains(request_method) - && capabilities.tools.is_none() - { - return Err(RpcError::internal_error() - .with_message(format_assertion_message(entity, "tools", request_method)) - .into()); - } + #[deprecated(since = "0.8.0", note = "Use `request_set_logging_level()` instead.")] + async fn set_logging_level( + &self, + params: SetLevelRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::SetLevelRequest(params), None) + .await?; + Ok(response.try_into()?) + } - Ok(()) + #[deprecated(since = "0.8.0", note = "Use `request_prompt()` instead.")] + async fn get_prompt( + &self, + params: GetPromptRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::GetPromptRequest(params), None) + .await?; + Ok(response.try_into()?) } - fn assert_client_notification_capabilities( + #[deprecated(since = "0.8.0", note = "Use `request_prompt_list()` instead.")] + async fn list_prompts( &self, - notification_method: &String, - ) -> std::result::Result<(), RpcError> { - let entity = "Client"; - let capabilities = &self.client_info().capabilities; + params: Option, + ) -> SdkResult { + let response = self + .request(RequestFromClient::ListPromptsRequest(params), None) + .await?; + Ok(response.try_into()?) + } - if *notification_method == RootsListChangedNotification::method_name() - && capabilities.roots.is_some() - { - return Err( - RpcError::internal_error().with_message(format_assertion_message( - entity, - "roots list changed notifications", - notification_method, - )), - ); - } + #[deprecated(since = "0.8.0", note = "Use `request_resource_list()` instead.")] + async fn list_resources( + &self, + params: Option, + ) -> SdkResult { + let response = self + .request(RequestFromClient::ListResourcesRequest(params), None) + .await?; + Ok(response.try_into()?) + } - Ok(()) + #[deprecated( + since = "0.8.0", + note = "Use `request_resource_template_list()` instead." + )] + async fn list_resource_templates( + &self, + params: Option, + ) -> SdkResult { + let response = self + .request( + RequestFromClient::ListResourceTemplatesRequest(params), + None, + ) + .await?; + Ok(response.try_into()?) } - fn assert_client_request_capabilities( + #[deprecated(since = "0.8.0", note = "Use `request_resource_read()` instead.")] + async fn read_resource( &self, - request_method: &String, - ) -> std::result::Result<(), RpcError> { - let entity = "Client"; - let capabilities = &self.client_info().capabilities; + params: ReadResourceRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::ReadResourceRequest(params), None) + .await?; + Ok(response.try_into()?) + } - if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some() - { - return Err( - RpcError::internal_error().with_message(format_assertion_message( - entity, - "sampling capability", - request_method, - )), - ); - } + #[deprecated( + since = "0.8.0", + note = "Use `request_resource_subscription()` instead." + )] + async fn subscribe_resource( + &self, + params: SubscribeRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::SubscribeRequest(params), None) + .await?; + Ok(response.try_into()?) + } - if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() { - return Err( - RpcError::internal_error().with_message(format_assertion_message( - entity, - "roots capability", - request_method, - )), - ); - } + #[deprecated( + since = "0.8.0", + note = "Use `request_resource_unsubscription()` instead." + )] + async fn unsubscribe_resource( + &self, + params: UnsubscribeRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromClient::UnsubscribeRequest(params), None) + .await?; + Ok(response.try_into()?) + } - Ok(()) + #[deprecated(since = "0.8.0", note = "Use `request_tool_call()` instead.")] + async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult { + let response = self + .request(RequestFromClient::CallToolRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + #[deprecated(since = "0.8.0", note = "Use `request_tool_list()` instead.")] + async fn list_tools( + &self, + params: Option, + ) -> SdkResult { + let response = self + .request(RequestFromClient::ListToolsRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + #[deprecated(since = "0.8.0", note = "Use `notify_roots_list_changed()` instead.")] + async fn send_roots_list_changed(&self, params: Option) -> SdkResult<()> { + self.send_notification(NotificationFromClient::RootsListChangedNotification(params)) + .await } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index cb37f2a..07ae9bc 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -60,3 +60,15 @@ pub trait McpClientHandler: Send + Sync { runtime: &dyn McpClient, ) -> SdkResult<()>; } + +// Custom trait for converting ServerHandler +#[cfg(feature = "server")] +pub trait ToMcpServerHandler { + fn to_mcp_server_handler(self) -> Arc; +} + +// Custom trait for converting ServerHandlerCore +#[cfg(feature = "server")] +pub trait ToMcpServerHandlerCore { + fn to_mcp_server_handler(self) -> Arc; +} diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 43e04f1..30ee184 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -1,28 +1,30 @@ use crate::auth::AuthInfo; -use crate::{error::SdkResult, utils::format_assertion_message}; - use crate::schema::{ schema_utils::{ ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, - ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest, - Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest, - ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams, - ListRootsResult, ListToolsRequest, LoggingMessageNotification, - LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification, - PromptListChangedNotificationParams, ReadResourceRequest, RequestId, - ResourceListChangedNotification, ResourceListChangedNotificationParams, - ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, - SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, + ElicitRequestParams, ElicitResult, GetPromptRequest, Implementation, InitializeRequestParams, + InitializeResult, ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, + ListRootsRequest, ListRootsResult, ListToolsRequest, LoggingMessageNotification, + LoggingMessageNotificationParams, NotificationParams, PromptListChangedNotification, + ReadResourceRequest, RequestId, RequestParams, ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, + ToolListChangedNotification, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; +use rust_mcp_schema::schema_utils::{CustomNotification, CustomRequest}; +use rust_mcp_schema::{ + CancelTaskParams, CancelTaskResult, CancelledNotificationParams, ElicitCompleteParams, + GenericResult, GetTaskParams, GetTaskPayloadParams, GetTaskResult, ProgressNotificationParams, + TaskStatusNotificationParams, +}; use rust_mcp_transport::SessionId; use std::{sync::Arc, time::Duration}; use tokio::sync::RwLockReadGuard; -//TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; @@ -36,191 +38,6 @@ pub trait McpServer: Sync + Send { async fn wait_for_initialization(&self); - async fn send( - &self, - message: MessageFromServer, - request_id: Option, - request_timeout: Option, - ) -> SdkResult>; - - async fn send_batch( - &self, - messages: Vec, - request_timeout: Option, - ) -> SdkResult>>; - - /// Checks whether the server has been initialized with client - fn is_initialized(&self) -> bool { - self.client_info().is_some() - } - - /// Returns the client's name and version information once initialization is complete. - /// This method retrieves the client details, if available, after successful initialization. - fn client_version(&self) -> Option { - self.client_info() - .map(|client_details| client_details.client_info) - } - - /// Returns the server's capabilities. - fn capabilities(&self) -> &ServerCapabilities { - &self.server_info().capabilities - } - - /// Sends an elicitation request to the client to prompt user input and returns the received response. - /// - /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema. - /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only - async fn elicit_input( - &self, - message: String, - requested_schema: ElicitRequestedSchema, - ) -> SdkResult { - let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams { - message, - requested_schema, - }); - let response = self.request(request.into(), None).await?; - ElicitResult::try_from(response).map_err(|err| err.into()) - } - - /// Sends a request to the client and processes the response. - /// - /// This function sends a `RequestFromServer` message to the client, waits for the response, - /// and handles the result. If the response is empty or of an invalid type, an error is returned. - /// Otherwise, it returns the result from the client. - async fn request( - &self, - request: RequestFromServer, - timeout: Option, - ) -> SdkResult { - // Send the request and receive the response. - let response = self - .send(MessageFromServer::RequestFromServer(request), None, timeout) - .await?; - - let client_message = response.ok_or_else(|| { - RpcError::internal_error() - .with_message("An empty response was received from the client.".to_string()) - })?; - - if client_message.is_error() { - return Err(client_message.as_error()?.error.into()); - } - - return Ok(client_message.as_response()?.result); - } - - /// Sends a notification. This is a one-way message that is not expected - /// to return any response. The method asynchronously sends the notification using - /// the transport layer and does not wait for any acknowledgement or result. - async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> { - self.send( - MessageFromServer::NotificationFromServer(notification), - None, - None, - ) - .await?; - Ok(()) - } - - /// Request a list of root URIs from the client. Roots allow - /// servers to ask for specific directories or files to operate on. A common example - /// for roots is providing a set of repositories or directories a server should operate on. - /// This request is typically used when the server needs to understand the file system - /// structure or access specific locations that the client has permission to read from - async fn list_roots( - &self, - params: Option, - ) -> SdkResult { - let request: ListRootsRequest = ListRootsRequest::new(params); - let response = self.request(request.into(), None).await?; - ListRootsResult::try_from(response).map_err(|err| err.into()) - } - - /// Send log message notification from server to client. - /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. - async fn send_logging_message( - &self, - params: LoggingMessageNotificationParams, - ) -> SdkResult<()> { - let notification = LoggingMessageNotification::new(params); - self.send_notification(notification.into()).await - } - - /// An optional notification from the server to the client, informing it that - /// the list of prompts it offers has changed. - /// This may be issued by servers without any previous subscription from the client. - async fn send_prompt_list_changed( - &self, - params: Option, - ) -> SdkResult<()> { - let notification = PromptListChangedNotification::new(params); - self.send_notification(notification.into()).await - } - - /// An optional notification from the server to the client, - /// informing it that the list of resources it can read from has changed. - /// This may be issued by servers without any previous subscription from the client. - async fn send_resource_list_changed( - &self, - params: Option, - ) -> SdkResult<()> { - let notification = ResourceListChangedNotification::new(params); - self.send_notification(notification.into()).await - } - - /// A notification from the server to the client, informing it that - /// a resource has changed and may need to be read again. - /// This should only be sent if the client previously sent a resources/subscribe request. - async fn send_resource_updated( - &self, - params: ResourceUpdatedNotificationParams, - ) -> SdkResult<()> { - let notification = ResourceUpdatedNotification::new(params); - self.send_notification(notification.into()).await - } - - /// An optional notification from the server to the client, informing it that - /// the list of tools it offers has changed. - /// This may be issued by servers without any previous subscription from the client. - async fn send_tool_list_changed( - &self, - params: Option, - ) -> SdkResult<()> { - let notification = ToolListChangedNotification::new(params); - self.send_notification(notification.into()).await - } - - /// A ping request to check that the other party is still alive. - /// The receiver must promptly respond, or else may be disconnected. - /// - /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response - /// Once the response is received, it attempts to convert it into the expected - /// result type. - /// - /// # Returns - /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. - /// If the request or conversion fails, an error is returned. - async fn ping(&self, timeout: Option) -> SdkResult { - let ping_request = PingRequest::new(None); - let response = self.request(ping_request.into(), timeout).await?; - Ok(response.try_into()?) - } - - /// A request from the server to sample an LLM via the client. - /// The client has full discretion over which model to select. - /// The client should also inform the user before beginning sampling, - /// to allow them to inspect the request (human in the loop) - /// and decide whether to approve it. - async fn create_message( - &self, - params: CreateMessageRequestParams, - ) -> SdkResult { - let ping_request = CreateMessageRequest::new(params); - let response = self.request(ping_request.into(), None).await?; - Ok(response.try_into()?) - } - /// Checks if the client supports sampling. /// /// This function retrieves the client information and checks if the @@ -284,7 +101,7 @@ pub trait McpServer: Sync + Send { request_method: &String, ) -> std::result::Result<(), RpcError> { let entity = "Client"; - if *request_method == CreateMessageRequest::method_name() + if *request_method == CreateMessageRequest::method_value() && !self.client_supports_sampling().unwrap_or(false) { return Err( @@ -295,7 +112,7 @@ pub trait McpServer: Sync + Send { )), ); } - if *request_method == ListRootsRequest::method_name() + if *request_method == ListRootsRequest::method_value() && !self.client_supports_root_list().unwrap_or(false) { return Err( @@ -317,7 +134,7 @@ pub trait McpServer: Sync + Send { let capabilities = &self.server_info().capabilities; - if *notification_method == LoggingMessageNotification::method_name() + if *notification_method == LoggingMessageNotification::method_value() && capabilities.logging.is_none() { return Err( @@ -328,7 +145,7 @@ pub trait McpServer: Sync + Send { )), ); } - if *notification_method == ResourceUpdatedNotification::method_name() + if *notification_method == ResourceUpdatedNotification::method_value() && capabilities.resources.is_none() { return Err( @@ -339,7 +156,7 @@ pub trait McpServer: Sync + Send { )), ); } - if *notification_method == ToolListChangedNotification::method_name() + if *notification_method == ToolListChangedNotification::method_value() && capabilities.tools.is_none() { return Err( @@ -350,7 +167,7 @@ pub trait McpServer: Sync + Send { )), ); } - if *notification_method == PromptListChangedNotification::method_name() + if *notification_method == PromptListChangedNotification::method_value() && capabilities.prompts.is_none() { return Err( @@ -367,12 +184,12 @@ pub trait McpServer: Sync + Send { fn assert_server_request_capabilities( &self, - request_method: &String, + request_method: &str, ) -> std::result::Result<(), RpcError> { let entity = "Server"; let capabilities = &self.server_info().capabilities; - if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() { + if request_method == SetLevelRequest::method_value() && capabilities.logging.is_none() { return Err( RpcError::internal_error().with_message(format_assertion_message( entity, @@ -382,10 +199,10 @@ pub trait McpServer: Sync + Send { ); } if [ - GetPromptRequest::method_name(), - ListPromptsRequest::method_name(), + GetPromptRequest::method_value(), + ListPromptsRequest::method_value(), ] - .contains(request_method) + .contains(&request_method) && capabilities.prompts.is_none() { return Err( @@ -397,11 +214,11 @@ pub trait McpServer: Sync + Send { ); } if [ - ListResourcesRequest::method_name(), - ListResourceTemplatesRequest::method_name(), - ReadResourceRequest::method_name(), + ListResourcesRequest::method_value(), + ListResourceTemplatesRequest::method_value(), + ReadResourceRequest::method_value(), ] - .contains(request_method) + .contains(&request_method) && capabilities.resources.is_none() { return Err( @@ -413,10 +230,10 @@ pub trait McpServer: Sync + Send { ); } if [ - CallToolRequest::method_name(), - ListToolsRequest::method_name(), + CallToolRequest::method_value(), + ListToolsRequest::method_value(), ] - .contains(request_method) + .contains(&request_method) && capabilities.tools.is_none() { return Err( @@ -432,4 +249,341 @@ pub trait McpServer: Sync + Send { #[cfg(feature = "hyper-server")] fn session_id(&self) -> Option; + + async fn send( + &self, + message: MessageFromServer, + request_id: Option, + request_timeout: Option, + ) -> SdkResult>; + + async fn send_batch( + &self, + messages: Vec, + request_timeout: Option, + ) -> SdkResult>>; + + /// Checks whether the server has been initialized with client + fn is_initialized(&self) -> bool { + self.client_info().is_some() + } + + /// Returns the client's name and version information once initialization is complete. + /// This method retrieves the client details, if available, after successful initialization. + fn client_version(&self) -> Option { + self.client_info() + .map(|client_details| client_details.client_info) + } + + /// Returns the server's capabilities. + fn capabilities(&self) -> &ServerCapabilities { + &self.server_info().capabilities + } + + /******************* + Requests + *******************/ + + /// Sends a request to the client and processes the response. + /// + /// This function sends a `RequestFromServer` message to the client, waits for the response, + /// and handles the result. If the response is empty or of an invalid type, an error is returned. + /// Otherwise, it returns the result from the client. + async fn request( + &self, + request: RequestFromServer, + timeout: Option, + ) -> SdkResult { + // Send the request and receive the response. + let response = self + .send(MessageFromServer::RequestFromServer(request), None, timeout) + .await?; + + let client_message = response.ok_or_else(|| { + RpcError::internal_error() + .with_message("An empty response was received from the client.".to_string()) + })?; + + if client_message.is_error() { + return Err(client_message.as_error()?.error.into()); + } + + return Ok(client_message.as_response()?.result); + } + + /// Sends an elicitation request to the client to prompt user input and returns the received response. + /// + /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema. + /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only + async fn request_elicitation(&self, params: ElicitRequestParams) -> SdkResult { + let response = self + .request(RequestFromServer::ElicitRequest(params), None) + .await?; + ElicitResult::try_from(response).map_err(|err| err.into()) + } + + /// Request a list of root URIs from the client. Roots allow + /// servers to ask for specific directories or files to operate on. A common example + /// for roots is providing a set of repositories or directories a server should operate on. + /// This request is typically used when the server needs to understand the file system + /// structure or access specific locations that the client has permission to read from + async fn request_root_list(&self, params: Option) -> SdkResult { + let response = self + .request(RequestFromServer::ListRootsRequest(params), None) + .await?; + ListRootsResult::try_from(response).map_err(|err| err.into()) + } + + /// A ping request to check that the other party is still alive. + /// The receiver must promptly respond, or else may be disconnected. + /// + /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response + /// Once the response is received, it attempts to convert it into the expected + /// result type. + /// + /// # Returns + /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. + /// If the request or conversion fails, an error is returned. + async fn ping( + &self, + params: Option, + timeout: Option, + ) -> SdkResult { + let response = self + .request(RequestFromServer::PingRequest(params), timeout) + .await?; + Ok(response.try_into()?) + } + + /// A request from the server to sample an LLM via the client. + /// The client has full discretion over which model to select. + /// The client should also inform the user before beginning sampling, + /// to allow them to inspect the request (human in the loop) + /// and decide whether to approve it. + async fn request_message_creation( + &self, + params: CreateMessageRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromServer::CreateMessageRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + ///Send a request to retrieve the state of a task. + async fn request_get_task(&self, params: GetTaskParams) -> SdkResult { + let response = self + .request(RequestFromServer::GetTaskRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + ///Send a request to retrieve the result of a completed task. + async fn request_get_task_payload( + &self, + params: GetTaskPayloadParams, + ) -> SdkResult { + let response = self + .request(RequestFromServer::GetTaskPayloadRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + ///Send a request to cancel a task. + async fn request_task_cancellation( + &self, + params: CancelTaskParams, + ) -> SdkResult { + let response = self + .request(RequestFromServer::CancelTaskRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + ///Send a custom request with a custom method name and params + async fn request_custom(&self, params: CustomRequest) -> SdkResult { + let response = self + .request(RequestFromServer::CustomRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + /******************* + Notifications + *******************/ + + /// Sends a notification. This is a one-way message that is not expected + /// to return any response. The method asynchronously sends the notification using + /// the transport layer and does not wait for any acknowledgement or result. + async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> { + self.send( + MessageFromServer::NotificationFromServer(notification), + None, + None, + ) + .await?; + Ok(()) + } + + /// Send log message notification from server to client. + /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. + async fn notify_log_message(&self, params: LoggingMessageNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromServer::LoggingMessageNotification(params)) + .await + } + + ///Send an optional notification from the server to the client, informing it that + /// the list of prompts it offers has changed. + /// This may be issued by servers without any previous subscription from the client. + async fn notify_prompt_list_changed( + &self, + params: Option, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::PromptListChangedNotification( + params, + )) + .await + } + + ///Send an optional notification from the server to the client, + /// informing it that the list of resources it can read from has changed. + /// This may be issued by servers without any previous subscription from the client. + async fn notify_resource_list_changed( + &self, + params: Option, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ResourceListChangedNotification( + params, + )) + .await + } + + ///Send a notification from the server to the client, informing it that + /// a resource has changed and may need to be read again. + /// This should only be sent if the client previously sent a resources/subscribe request. + async fn notify_resource_updated( + &self, + params: ResourceUpdatedNotificationParams, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ResourceUpdatedNotification(params)) + .await + } + + ///Send an optional notification from the server to the client, informing it that + /// the list of tools it offers has changed. + /// This may be issued by servers without any previous subscription from the client. + async fn notify_tool_list_changed(&self, params: Option) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ToolListChangedNotification(params)) + .await + } + + /// This notification can be sent to indicate that it is cancelling a previously-issued request. + /// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + /// This notification indicates that the result will be unused, so any associated processing SHOULD cease. + /// A client MUST NOT attempt to cancel its initialize request. + /// For task cancellation, use the tasks/cancel request instead of this notification. + async fn notify_cancellation(&self, params: CancelledNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromServer::CancelledNotification(params)) + .await + } + + ///Send an out-of-band notification used to inform the receiver of a progress update for a long-running request. + async fn notify_progress(&self, params: ProgressNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ProgressNotification(params)) + .await + } + + /// Send an optional notification from the receiver to the requestor, informing them that a task's status has changed. + /// Receivers are not required to send these notifications. + async fn notify_task_status(&self, params: TaskStatusNotificationParams) -> SdkResult<()> { + self.send_notification(NotificationFromServer::TaskStatusNotification(params)) + .await + } + + ///An optional notification from the server to the client, informing it of a completion of a out-of-band elicitation request. + async fn notify_elicitation_completed(&self, params: ElicitCompleteParams) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ElicitationCompleteNotification( + params, + )) + .await + } + + ///Send a custom notification + async fn notify_custom(&self, params: CustomNotification) -> SdkResult<()> { + self.send_notification(NotificationFromServer::CustomNotification(params)) + .await + } + + #[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")] + async fn list_roots(&self, params: Option) -> SdkResult { + let response = self + .request(RequestFromServer::ListRootsRequest(params), None) + .await?; + ListRootsResult::try_from(response).map_err(|err| err.into()) + } + + #[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")] + async fn elicit_input(&self, params: ElicitRequestParams) -> SdkResult { + let response = self + .request(RequestFromServer::ElicitRequest(params), None) + .await?; + ElicitResult::try_from(response).map_err(|err| err.into()) + } + + #[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")] + async fn create_message( + &self, + params: CreateMessageRequestParams, + ) -> SdkResult { + let response = self + .request(RequestFromServer::CreateMessageRequest(params), None) + .await?; + Ok(response.try_into()?) + } + + #[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")] + async fn send_tool_list_changed(&self, params: Option) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ToolListChangedNotification(params)) + .await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")] + async fn send_resource_updated( + &self, + params: ResourceUpdatedNotificationParams, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ResourceUpdatedNotification(params)) + .await + } + + #[deprecated( + since = "0.8.0", + note = "Use `notify_resource_list_changed()` instead." + )] + async fn send_resource_list_changed( + &self, + params: Option, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::ResourceListChangedNotification( + params, + )) + .await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")] + async fn send_prompt_list_changed(&self, params: Option) -> SdkResult<()> { + self.send_notification(NotificationFromServer::PromptListChangedNotification( + params, + )) + .await + } + + #[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")] + async fn send_logging_message( + &self, + params: LoggingMessageNotificationParams, + ) -> SdkResult<()> { + self.send_notification(NotificationFromServer::LoggingMessageNotification(params)) + .await + } } diff --git a/crates/rust-mcp-sdk/src/schema.rs b/crates/rust-mcp-sdk/src/schema.rs index bc008c2..77fe186 100644 --- a/crates/rust-mcp-sdk/src/schema.rs +++ b/crates/rust-mcp-sdk/src/schema.rs @@ -1,17 +1,3 @@ -#[cfg(feature = "2025_06_18")] -pub use rust_mcp_schema::*; - -#[cfg(not(feature = "2025_06_18"))] +pub use rust_mcp_schema::mcp_2025_11_25::*; +// always export pub use rust_mcp_schema::{ParseProtocolVersionError, ProtocolVersion}; - -#[cfg(all( - feature = "2025_03_26", - not(any(feature = "2024_11_05", feature = "2025_06_18")) -))] -pub use rust_mcp_schema::mcp_2025_03_26::*; - -#[cfg(all( - feature = "2024_11_05", - not(any(feature = "2025_03_26", feature = "2025_06_18")) -))] -pub use rust_mcp_schema::mcp_2024_11_05::*; diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index c63010d..bec1577 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -38,7 +38,7 @@ impl Drop for AbortTaskOnDrop { /// /// # Examples /// ```ignore -/// let msg = format_assertion_message("Server", "tools", rust_mcp_schema::ListResourcesRequest::method_name()); +/// let msg = format_assertion_message("Server", "tools", rust_mcp_schema::ListResourcesRequest::method_value()); /// assert_eq!(msg, "Server does not support resources (required for resources/list)"); /// ``` pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String { diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs index 207644e..b7ccbb6 100644 --- a/crates/rust-mcp-sdk/tests/check_imports.rs +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -1,87 +1,87 @@ -#[cfg(test)] -mod tests { - use std::fs::File; - use std::io::{self, Read}; - use std::path::{Path, MAIN_SEPARATOR_STR}; +// #[cfg(test)] +// mod tests { +// use std::fs::File; +// use std::io::{self, Read}; +// use std::path::{Path, MAIN_SEPARATOR_STR}; - // List of files to exclude from the check - const EXCLUDED_FILES: &[&str] = &["src/schema.rs"]; +// // List of files to exclude from the check +// const EXCLUDED_FILES: &[&str] = &["src/schema.rs"]; - // Check all .rs files for incorrect `use rust_mcp_schema` imports - #[test] - fn check_no_rust_mcp_schema_imports() { - let mut errors = Vec::new(); +// // Check all .rs files for incorrect `use rust_mcp_schema` imports +// #[test] +// fn check_no_rust_mcp_schema_imports() { +// let mut errors = Vec::new(); - // Walk through the src directory - for entry in walk_src_dir("src").expect("Failed to read src directory") { - let entry = entry.unwrap(); - let path = entry.path(); +// // Walk through the src directory +// for entry in walk_src_dir("src").expect("Failed to read src directory") { +// let entry = entry.unwrap(); +// let path = entry.path(); - // only check files with .rs extension - if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("rs") { - let abs_path = path.to_string_lossy(); - let relative_path = path.strip_prefix("src").unwrap_or(&path); - let path_str = relative_path.to_string_lossy(); +// // only check files with .rs extension +// if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("rs") { +// let abs_path = path.to_string_lossy(); +// let relative_path = path.strip_prefix("src").unwrap_or(&path); +// let path_str = relative_path.to_string_lossy(); - // Skip excluded files - if EXCLUDED_FILES - .iter() - .any(|&excluded| abs_path.replace(MAIN_SEPARATOR_STR, "/") == excluded) - { - continue; - } +// // Skip excluded files +// if EXCLUDED_FILES +// .iter() +// .any(|&excluded| abs_path.replace(MAIN_SEPARATOR_STR, "/") == excluded) +// { +// continue; +// } - // Read the file content - match read_file(&path) { - Ok(content) => { - // Check for `use rust_mcp_schema` - if content.contains("use rust_mcp_schema") { - errors.push(format!( - "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." - )); - } - } - Err(e) => { - errors.push(format!("Failed to read file `{path_str}`: {e}")); - } - } - } - } +// // Read the file content +// match read_file(&path) { +// Ok(content) => { +// // Check for `use rust_mcp_schema` +// if content.contains("use rust_mcp_schema") { +// errors.push(format!( +// "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." +// )); +// } +// } +// Err(e) => { +// errors.push(format!("Failed to read file `{path_str}`: {e}")); +// } +// } +// } +// } - // If there are any errors, fail the test with all error messages - if !errors.is_empty() { - panic!( - "Found {} incorrect imports:\n{}\n\n", - errors.len(), - errors.join("\n") - ); - } - } +// // If there are any errors, fail the test with all error messages +// if !errors.is_empty() { +// panic!( +// "Found {} incorrect imports:\n{}\n\n", +// errors.len(), +// errors.join("\n") +// ); +// } +// } - // Helper function to walk the src directory - fn walk_src_dir>( - path: P, - ) -> io::Result>> { - Ok(std::fs::read_dir(path)?.flat_map(|entry| { - let entry = entry.unwrap(); - let path = entry.path(); - if path.is_dir() { - // Recursively walk subdirectories - walk_src_dir(&path) - .into_iter() - .flatten() - .collect::>() - } else { - vec![Ok(entry)] - } - })) - } +// // Helper function to walk the src directory +// fn walk_src_dir>( +// path: P, +// ) -> io::Result>> { +// Ok(std::fs::read_dir(path)?.flat_map(|entry| { +// let entry = entry.unwrap(); +// let path = entry.path(); +// if path.is_dir() { +// // Recursively walk subdirectories +// walk_src_dir(&path) +// .into_iter() +// .flatten() +// .collect::>() +// } else { +// vec![Ok(entry)] +// } +// })) +// } - // Helper function to read file content - fn read_file(path: &Path) -> io::Result { - let mut file = File::open(path)?; - let mut content = String::new(); - file.read_to_string(&mut content)?; - Ok(content) - } -} +// // Helper function to read file content +// fn read_file(path: &Path) -> io::Result { +// let mut file = File::open(path)?; +// let mut content = String::new(); +// file.read_to_string(&mut content)?; +// Ok(content) +// } +// } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 8e61704..cc21361 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -6,9 +6,9 @@ pub use mock_server::*; use reqwest::{Client, Response, Url}; use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; -use rust_mcp_sdk::mcp_client::ClientHandler; - use rust_mcp_sdk::auth::{AuthInfo, AuthenticationError, OauthTokenVerifier}; +use rust_mcp_sdk::mcp_client::ClientHandler; +use rust_mcp_sdk::mcp_icon; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; @@ -243,10 +243,18 @@ pub fn test_client_info() -> InitializeRequestParams { client_info: Implementation { name: "test-rust-mcp-client".into(), version: "0.1.0".into(), - #[cfg(feature = "2025_06_18")] + description: None, + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], title: None, + website_url: None, }, protocol_version: ProtocolVersion::V2025_03_26.to_string(), + meta: None, } } @@ -331,9 +339,8 @@ pub fn random_port_old() -> u16 { pub mod sample_tools { use std::{sync::Arc, time::Duration}; + use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::{LoggingMessageNotificationParams, TextContent}; - #[cfg(feature = "2025_06_18")] - use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; use rust_mcp_sdk::{ schema::{schema_utils::CallToolError, CallToolResult}, McpServer, @@ -351,7 +358,7 @@ pub mod sample_tools { open_world_hint = false, read_only_hint = false )] - #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, rust_mcp_macros::JsonSchema)] pub struct SayHelloTool { /// The name of the person to greet with a "Hello". pub name: String, @@ -361,12 +368,7 @@ pub mod sample_tools { pub fn call_tool(&self) -> Result { let hello_message = format!("Hello, {}!", self.name); - #[cfg(feature = "2025_06_18")] - return Ok(CallToolResult::text_content(vec![ - rust_mcp_sdk::schema::TextContent::from(hello_message), - ])); - #[cfg(not(feature = "2025_06_18"))] - return Ok(CallToolResult::text_content(hello_message, None)); + return Ok(CallToolResult::text_content(vec![hello_message.into()])); } } @@ -390,12 +392,8 @@ pub mod sample_tools { auth_info: Option, ) -> Result { let message = format!("{}", serde_json::to_string(&auth_info).unwrap()); - #[cfg(feature = "2025_06_18")] - return Ok(CallToolResult::text_content(vec![ - rust_mcp_sdk::schema::TextContent::from(message), - ])); - #[cfg(not(feature = "2025_06_18"))] - return Ok(CallToolResult::text_content(message, None)); + + return Ok(CallToolResult::text_content(vec![message.into()])); } } @@ -419,12 +417,7 @@ pub mod sample_tools { pub fn call_tool(&self) -> Result { let goodbye_message = format!("Goodbye, {}!", self.name); - #[cfg(feature = "2025_06_18")] - return Ok(CallToolResult::text_content(vec![ - rust_mcp_sdk::schema::TextContent::from(goodbye_message), - ])); - #[cfg(not(feature = "2025_06_18"))] - return Ok(CallToolResult::text_content(goodbye_message, None)); + return Ok(CallToolResult::text_content(vec![goodbye_message.into()])); } } @@ -453,6 +446,7 @@ pub mod sample_tools { data: json!({"id":format!("message {} of {}",i,self.count)}), level: rust_mcp_sdk::schema::LoggingLevel::Emergency, logger: None, + meta: None, }) .await; tokio::time::sleep(Duration::from_millis(self.interval)).await; diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs index 912501a..b30c17a 100644 --- a/crates/rust-mcp-sdk/tests/common/test_client.rs +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -1,5 +1,8 @@ use async_trait::async_trait; -use rust_mcp_schema::{schema_utils::MessageFromServer, PingRequest, RpcError}; +use rust_mcp_schema::{ + schema_utils::{MessageFromServer, RequestFromServer}, + PingRequest, RequestParams, RpcError, +}; use rust_mcp_sdk::{mcp_client::ClientHandler, McpClient}; use serde_json::json; use std::sync::Arc; @@ -13,7 +16,7 @@ pub mod test_client_common { }; use rust_mcp_sdk::{ mcp_client::{client_runtime, ClientRuntime}, - McpClient, RequestOptions, SessionId, StreamableTransportOptions, + mcp_icon, McpClient, RequestOptions, SessionId, StreamableTransportOptions, }; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::sync::RwLock; @@ -34,7 +37,7 @@ pub mod test_client_common { } pub const TEST_SESSION_ID: &str = "test-session-id"; - pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-06-18"}}"#; + pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"icons":[{"mimeType":"image/png","sizes":["128x128"],"src":"https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png","theme":"dark"}],"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-11-25"}}"#; pub fn test_client_details() -> InitializeRequestParams { InitializeRequestParams { @@ -43,8 +46,17 @@ pub mod test_client_common { name: "simple-rust-mcp-client-sse".to_string(), version: "0.1.0".to_string(), title: Some("Simple Rust MCP Client (SSE)".to_string()), + description: None, + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: None, }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, } } @@ -147,10 +159,13 @@ impl TestClientHandler { impl ClientHandler for TestClientHandler { async fn handle_ping_request( &self, - request: PingRequest, + params: Option, _runtime: &dyn McpClient, ) -> std::result::Result { - self.register_message(&request.into()).await; + self.register_message(&MessageFromServer::RequestFromServer( + RequestFromServer::PingRequest(params), + )) + .await; Ok(rust_mcp_schema::Result { meta: Some(json!({"meta_number":1515}).as_object().unwrap().to_owned()), diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 9c8e6ee..4b5d2d0 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -2,20 +2,23 @@ pub mod test_server_common { use crate::common::sample_tools::{DisplayAuthInfo, SayHelloTool}; use async_trait::async_trait; - use rust_mcp_schema::schema_utils::CallToolError; + use rust_mcp_schema::schema_utils::{CallToolError, RequestFromClient}; use rust_mcp_schema::{ - CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, - RpcError, + CallToolRequest, CallToolRequestParams, CallToolResult, ListToolsRequest, ListToolsResult, + PaginatedRequestParams, ProtocolVersion, RpcError, }; use rust_mcp_sdk::event_store::EventStore; use rust_mcp_sdk::id_generator::IdGenerator; + use rust_mcp_sdk::mcp_icon; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, ServerHandler}, + mcp_server::{ + hyper_server, HyperServer, HyperServerOptions, ServerHandler, ToMcpServerHandler, + }, McpServer, SessionId, }; use std::sync::{Arc, RwLock}; @@ -23,9 +26,9 @@ pub mod test_server_common { use tokio::time::timeout; use tokio_stream::StreamExt; - pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; + pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; - pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-06-18","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; + pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-11-25","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; pub struct LaunchedServer { pub hyper_runtime: HyperRuntime, @@ -35,8 +38,8 @@ pub mod test_server_common { pub event_store: Option>, } - pub fn initialize_request() -> InitializeRequest { - InitializeRequest::new(InitializeRequestParams { + pub fn initialize_request() -> RequestFromClient { + RequestFromClient::InitializeRequest(InitializeRequestParams { capabilities: ClientCapabilities { ..Default::default() }, @@ -44,8 +47,17 @@ pub mod test_server_common { name: "test-server".to_string(), title: None, version: "0.1.0".to_string(), + description: None, + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: None, }, - protocol_version: ProtocolVersion::V2025_06_18.to_string(), + protocol_version: ProtocolVersion::V2025_11_25.to_string(), + meta: None, }) } @@ -55,8 +67,15 @@ pub mod test_server_common { server_info: Implementation { name: "Test MCP Server".to_string(), version: "0.1.0".to_string(), - #[cfg(feature = "2025_06_18")] title: None, + description: None, + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: None, }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -65,7 +84,7 @@ pub mod test_server_common { }, meta: None, instructions: Some("server instructions...".to_string()), - protocol_version: ProtocolVersion::V2025_06_18.to_string(), + protocol_version: ProtocolVersion::V2025_11_25.to_string(), } } @@ -75,10 +94,10 @@ pub mod test_server_common { impl ServerHandler for TestServerHandler { async fn handle_list_tools_request( &self, - request: ListToolsRequest, + _params: Option, runtime: Arc, ) -> std::result::Result { - runtime.assert_server_request_capabilities(request.method())?; + runtime.assert_server_request_capabilities(&ListToolsRequest::method_value())?; Ok(ListToolsResult { meta: None, @@ -89,17 +108,17 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { runtime - .assert_server_request_capabilities(request.method()) + .assert_server_request_capabilities(&CallToolRequest::method_value()) .map_err(CallToolError::new)?; - match request.params.name.as_str() { + match params.name.as_str() { "say_hello" => { let tool = SayHelloTool { - name: request.params.arguments.unwrap()["name"] + name: params.arguments.unwrap()["name"] .as_str() .unwrap() .to_string(), @@ -111,17 +130,19 @@ pub mod test_server_common { let tool = DisplayAuthInfo {}; Ok(tool.call_tool(runtime.auth_info_cloned().await).unwrap()) } - _ => Ok(CallToolError::unknown_tool(format!( - "Unknown tool: {}", - request.params.name - )) - .into()), + _ => Ok( + CallToolError::unknown_tool(format!("Unknown tool: {}", params.name)).into(), + ), } } } pub fn create_test_server(options: HyperServerOptions) -> HyperServer { - hyper_server::create_server(test_server_details(), TestServerHandler {}, options) + hyper_server::create_server( + test_server_details(), + TestServerHandler {}.to_mcp_server_handler(), + options, + ) } pub async fn create_start_server(options: HyperServerOptions) -> LaunchedServer { @@ -130,8 +151,11 @@ pub mod test_server_common { let sse_message_url = options.sse_message_url(); let event_store_clone = options.event_store.clone(); - let server = - hyper_server::create_server(test_server_details(), TestServerHandler {}, options); + let server = hyper_server::create_server( + test_server_details(), + TestServerHandler {}.to_mcp_server_handler(), + options, + ); let hyper_runtime = HyperRuntime::create(server).await.unwrap(); @@ -197,12 +221,12 @@ pub mod test_server_common { // Check if we have collected 5 lines if collected_lines.len() >= line_count { - return Ok(collected_lines); + return Ok(collected_lines.clone()); } } } // If the stream ends before collecting 5 lines, return what we have - Ok(collected_lines) + Ok(collected_lines.clone()) }) .await; @@ -212,7 +236,11 @@ pub mod test_server_common { Ok(Err(e)) => Err(e), Err(_) => Err(Box::new(std::io::Error::new( std::io::ErrorKind::TimedOut, - "Timed out waiting for 5 lines", + format!( + "Timed out waiting for 5 lines, received({}): {}", + collected_lines.len(), + collected_lines.join(" \n ") + ), ))), } } diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 9f2fd95..e6cff35 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -3,8 +3,8 @@ pub mod common; mod protocol_compatibility_on_server { - use rust_mcp_sdk::mcp_server::ServerHandler; - use rust_mcp_sdk::schema::{InitializeRequest, InitializeResult, RpcError, INTERNAL_ERROR}; + use rust_mcp_sdk::mcp_server::{ServerHandler, ToMcpServerHandler}; + use rust_mcp_sdk::schema::{InitializeResult, RpcError, INTERNAL_ERROR}; use crate::common::{ test_client_info, @@ -26,11 +26,11 @@ mod protocol_compatibility_on_server { let runtime = rust_mcp_sdk::mcp_server::server_runtime::create_server( test_server_details(), transport, - TestServerHandler {}, + TestServerHandler {}.to_mcp_server_handler(), ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) + .handle_initialize_request(initialize_request, runtime) .await } diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs index 1148cca..5e4912a 100644 --- a/crates/rust-mcp-sdk/tests/test_server_sse.rs +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -208,8 +208,7 @@ mod tets_server_sse { let result = serde_json::from_str::(&init_response).unwrap(); assert!(matches!(result, ServerMessage::Response(response) - if matches!(&response.result, ResultFromServer::ServerResult(server_result) - if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server")))); + if matches!(&response.result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server"))); handle.graceful_shutdown(Some(Duration::from_millis(1))); server_task.await.unwrap(); } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index ceb778a..a6bcf63 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -5,7 +5,7 @@ use common::test_client_common::create_client; use hyper::{Method, StatusCode}; use rust_mcp_schema::{ schema_utils::{ - ClientJsonrpcRequest, ClientMessage, MessageFromServer, RequestFromClient, + ClientJsonrpcRequest, ClientMessage, CustomRequest, MessageFromServer, RequestFromClient, RequestFromServer, ResultFromServer, RpcMessage, ServerMessage, }, RequestId, ServerRequest, ServerResult, @@ -14,7 +14,7 @@ use rust_mcp_sdk::{ error::McpSdkError, mcp_server::HyperServerOptions, McpClient, TransportError, MCP_LAST_EVENT_ID_HEADER, }; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; use wiremock::{ http::{HeaderName, HeaderValue}, @@ -101,12 +101,18 @@ async fn should_send_batch_messages() { let message_1: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id1".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test1".to_string(), + params: Some(Map::new()), + }), ) .into(); let message_2: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id2".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test2".to_string(), + params: Some(Map::new()), + }), ) .into(); @@ -235,7 +241,7 @@ async fn should_handle_404_response_when_session_expires() { .mount(&mock_server) .await; - let result = client.ping(None).await; + let result = client.ping(None, None).await; matches!( result, @@ -265,11 +271,16 @@ async fn should_handle_non_streaming_json_response() { .mount(&mock_server) .await; - let request = RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})); + let request = RequestFromClient::CustomRequest(CustomRequest { + method: "test1".to_string(), + params: Some(Map::new()), + }); let result = client.request(request, None).await.unwrap(); - let ResultFromServer::ServerResult(ServerResult::Result(result)) = result else { + println!(">>> result {:?} ", result); + + let ResultFromServer::Result(result) = result else { panic!("Wrong result variant!") }; @@ -360,7 +371,7 @@ async fn should_receive_server_initiated_messaged() { tokio::time::sleep(Duration::from_secs(1)).await; let result = hyper_runtime - .ping(&"AAA-BBB-CCC".to_string(), None) + .ping(&"AAA-BBB-CCC".to_string(), None, None) .await .unwrap(); @@ -370,15 +381,11 @@ async fn should_receive_server_initiated_messaged() { .find(|m| { matches!( m, - MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( - ServerRequest::PingRequest(_) - )) + MessageFromServer::RequestFromServer(RequestFromServer::PingRequest(_)) ) }) .unwrap(); - let MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( - ServerRequest::PingRequest(_), - )) = ping_request + let MessageFromServer::RequestFromServer(RequestFromServer::PingRequest(_)) = ping_request else { panic!("Request is not a match!") }; @@ -461,12 +468,18 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { let message_1: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id1".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test1".to_string(), + params: Some(Map::new()), + }), ) .into(); let message_2: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id2".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test2".to_string(), + params: Some(Map::new()), + }), ) .into(); @@ -495,12 +508,18 @@ async fn should_handle_multiple_concurrent_sse_streams() { let message_1: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id1".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test1".to_string(), + params: Some(Map::new()), + }), ) .into(); let message_2: ClientMessage = ClientJsonrpcRequest::new( RequestId::String("id2".to_string()), - RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + RequestFromClient::CustomRequest(CustomRequest { + method: "test2".to_string(), + params: Some(Map::new()), + }), ) .into(); @@ -520,7 +539,10 @@ async fn should_handle_multiple_concurrent_sse_streams() { .mount(&mock_server) .await; - let message_3 = RequestFromClient::CustomRequest(json!({"method": "test3", "params": {}})); + let message_3 = RequestFromClient::CustomRequest(CustomRequest { + method: "test3".to_string(), + params: Some(Map::new()), + }); let request1 = client.send_batch(vec![message_1, message_2], None); let request2 = client.send(message_3.into(), None, None); @@ -567,7 +589,7 @@ async fn should_throw_error_when_invalid_content_type_is_received() { .mount(&mock_server) .await; - let result = client.ping(None).await; + let result = client.ping(None, None).await; let Err(McpSdkError::Transport(TransportError::UnexpectedContentType(content_type))) = result else { @@ -598,7 +620,7 @@ async fn should_always_send_specified_custom_headers() { .mount(&mock_server) .await; - let _result = client.ping(None).await; + let _result = client.ping(None, None).await; let requests = mock_server.received_requests().await.unwrap(); diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 43f162d..70fd81c 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -8,14 +8,16 @@ use crate::common::{ }; use http::header::{ACCEPT, ACCESS_CONTROL_ALLOW_ORIGIN, AUTHORIZATION, CONTENT_TYPE}; use hyper::StatusCode; +use rust_mcp_macros::{mcp_elicit, JsonSchema}; use rust_mcp_schema::{ schema_utils::{ ClientJsonrpcRequest, ClientJsonrpcResponse, ClientMessage, ClientMessages, FromMessage, - NotificationFromServer, RequestFromServer, ResultFromServer, RpcMessage, SdkError, - SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, - ServerMessages, + MessageFromClient, NotificationFromClient, NotificationFromServer, RequestFromClient, + RequestFromServer, ResultFromServer, RpcMessage, SdkError, SdkErrorCodes, + ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListRootsResult, ListToolsRequest, LoggingLevel, + CallToolRequest, CallToolRequestParams, ClientRequest, ElicitResult, ElicitResultContent, + JsonrpcResultResponse, ListRootsRequest, ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; @@ -23,6 +25,7 @@ use rust_mcp_sdk::{ auth::{AuthInfo, AuthMetadataBuilder, AuthProvider, RemoteAuthProvider}, event_store::InMemoryEventStore, mcp_server::HyperServerOptions, + schema::{MessageFromServer, ResultFromClient}, }; use serde_json::{json, Map, Value}; use std::{ @@ -45,7 +48,7 @@ async fn initialize_server( auth_token_map: Option>, ) -> Result<(LaunchedServer, String), Box> { let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request()); let port = random_port(); @@ -223,8 +226,10 @@ async fn should_reject_batch_initialize_request() { async fn should_handle_post_requests_via_sse_response_correctly() { let (server, session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None), + ); let response = send_post_request( &server.streamable_url, @@ -242,8 +247,7 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert!(matches!(message.id, RequestId::Integer(1))); - let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result - else { + let ResultFromServer::ListToolsResult(result) = message.result else { panic!("invalid ListToolsResult") }; @@ -270,9 +274,11 @@ async fn should_call_a_tool_and_return_the_result() { let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::Integer(1), - CallToolRequest::new(CallToolRequestParams { + RequestFromClient::CallToolRequest(CallToolRequestParams { arguments: Some(map), name: "say_hello".to_string(), + meta: None, + task: None, }) .into(), ); @@ -293,8 +299,7 @@ async fn should_call_a_tool_and_return_the_result() { assert!(matches!(message.id, RequestId::Integer(1))); - let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result - else { + let ResultFromServer::CallToolResult(result) = message.result else { panic!("invalid CallToolResult") }; @@ -313,8 +318,10 @@ async fn should_call_a_tool_and_return_the_result() { async fn should_reject_requests_without_a_valid_session_id() { let (server, _session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None), + ); let response = send_post_request( &server.streamable_url, @@ -338,8 +345,10 @@ async fn should_reject_requests_without_a_valid_session_id() { async fn should_reject_invalid_session_id() { let (server, _session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None), + ); let response = send_post_request( &server.streamable_url, @@ -416,6 +425,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { data: json!("Test notification"), level: rust_mcp_schema::LoggingLevel::Info, logger: None, + meta: None, }, ) .await @@ -424,10 +434,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { let events = read_sse_event(response, 1).await.unwrap(); let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification) = message else { panic!("invalid message received!"); }; @@ -508,20 +515,18 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0].2).unwrap(); - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request - else { + let ServerJsonrpcRequest::ListRootsRequest(_) = message1 else { panic!("invalid message received!"); }; let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1].2).unwrap(); - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request - else { + let ServerJsonrpcRequest::ListRootsRequest(_) = message1 else { panic!("invalid message received!"); }; // ensure request_ids are unique - assert!(message2.id != message1.id); + assert!(message2.request_id() != message1.request_id()); hyper_server.graceful_shutdown(ONE_MILLISECOND); } @@ -542,6 +547,7 @@ async fn should_not_close_get_sse_stream() { data: json!("First notification"), level: rust_mcp_schema::LoggingLevel::Info, logger: None, + meta: None, }, ) .await @@ -551,10 +557,7 @@ async fn should_not_close_get_sse_stream() { let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification) = message else { panic!("invalid message received!"); }; @@ -572,6 +575,7 @@ async fn should_not_close_get_sse_stream() { data: json!("Second notification"), level: rust_mcp_schema::LoggingLevel::Info, logger: None, + meta: None, }, ) .await @@ -580,10 +584,7 @@ async fn should_not_close_get_sse_stream() { let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification_2, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification_2) = message else { panic!("invalid message received!"); }; @@ -642,8 +643,10 @@ async fn should_reject_get_requests() { async fn reject_post_requests_without_accept_header() { let (server, session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None).into(), + ); let mut headers = HashMap::new(); headers.insert("Accept", "application/json"); @@ -676,8 +679,10 @@ async fn reject_post_requests_without_accept_header() { async fn should_reject_unsupported_content_type() { let (server, session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None), + ); let mut headers = HashMap::new(); headers.insert("Content-Type", "text/plain"); @@ -712,8 +717,20 @@ async fn should_handle_batch_notification_messages_with_202_response() { let (server, session_id) = initialize_server(None, None).await.unwrap(); let batch_notification = ClientMessages::Batch(vec![ - ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), - ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), + ClientMessage::from_message( + MessageFromClient::NotificationFromClient( + NotificationFromClient::RootsListChangedNotification(None), + ), + None, + ) + .unwrap(), + ClientMessage::from_message( + MessageFromClient::NotificationFromClient( + NotificationFromClient::RootsListChangedNotification(None), + ), + None, + ) + .unwrap(), ]); let response = send_post_request( @@ -753,17 +770,21 @@ async fn should_properly_handle_invalid_json_data() { async fn should_send_response_messages_to_the_connection_that_sent_the_request() { let (server, session_id) = initialize_server(None, None).await.unwrap(); - let json_rpc_message_1: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None).into(), + ); let mut map = Map::new(); map.insert("name".to_string(), Value::String("Ali".to_string())); let json_rpc_message_2: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::Integer(1), - CallToolRequest::new(CallToolRequestParams { + RequestFromClient::CallToolRequest(CallToolRequestParams { arguments: Some(map), name: "say_hello".to_string(), + meta: None, + task: None, }) .into(), ); @@ -794,8 +815,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert!(matches!(message.id, RequestId::Integer(1))); - let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result - else { + let ResultFromServer::CallToolResult(result) = message.result else { panic!("invalid CallToolResult") }; @@ -810,8 +830,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert!(matches!(message.id, RequestId::Integer(1))); - let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result - else { + let ResultFromServer::ListToolsResult(result) = message.result else { panic!("invalid ListToolsResult") }; @@ -892,8 +911,10 @@ async fn should_accept_requests_without_protocol_version_header() { headers.insert("Content-Type", "application/json"); headers.insert("Accept", "application/json, text/event-stream"); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None).into(), + ); let response = send_post_request( &server.streamable_url, @@ -920,8 +941,10 @@ async fn should_reject_requests_with_unsupported_protocol_version() { headers.insert("Accept", "application/json, text/event-stream"); headers.insert("mcp-protocol-version", "1999-15-21"); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None).into(), + ); let response = send_post_request( &server.streamable_url, @@ -1000,8 +1023,10 @@ async fn should_handle_protocol_version_validation_for_delete_requests() { async fn should_return_json_response_for_a_single_request() { let (server, session_id) = initialize_server(Some(true), None).await.unwrap(); - let json_rpc_message: ClientJsonrpcRequest = - ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + RequestFromClient::ListToolsRequest(None).into(), + ); let response = send_post_request( &server.streamable_url, @@ -1021,8 +1046,7 @@ async fn should_return_json_response_for_a_single_request() { let message = response.json::().await.unwrap(); - let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result - else { + let ResultFromServer::ListToolsResult(result) = message.result else { panic!("invalid ListToolsResult") }; @@ -1046,23 +1070,31 @@ async fn should_return_json_response_for_a_batch_request() { let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_1".to_string()), - ListToolsRequest::new(None).into(), + RequestFromClient::ListToolsRequest(None).into(), ); let mut map = Map::new(); map.insert("name".to_string(), Value::String("Ali".to_string())); let json_rpc_message_3: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_2".to_string()), - CallToolRequest::new(CallToolRequestParams { + RequestFromClient::CallToolRequest(CallToolRequestParams { arguments: Some(map), name: "say_hello".to_string(), + meta: None, + task: None, }) .into(), ); let batch_message = ClientMessages::Batch(vec![ json_rpc_message_1.into(), - ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), + ClientMessage::from_message( + MessageFromClient::NotificationFromClient( + NotificationFromClient::RootsListChangedNotification(None), + ), + None, + ) + .unwrap(), json_rpc_message_3.into(), ]); @@ -1096,9 +1128,7 @@ async fn should_return_json_response_for_a_batch_request() { result_1.request_id().unwrap(), RequestId::String("req_1".to_string()) ); - let ResultFromServer::ServerResult(ServerResult::ListToolsResult(_)) = - result_1.as_response().unwrap().result - else { + let ResultFromServer::ListToolsResult(_) = result_1.as_response().unwrap().result else { panic!("Expected a ListToolsResult"); }; @@ -1107,9 +1137,7 @@ async fn should_return_json_response_for_a_batch_request() { result_2.request_id().unwrap(), RequestId::String("req_2".to_string()) ); - let ResultFromServer::ServerResult(ServerResult::CallToolResult(_)) = - result_2.as_response().unwrap().result - else { + let ResultFromServer::CallToolResult(_) = result_2.as_response().unwrap().result else { panic!("Expected a CallToolResult"); }; @@ -1124,16 +1152,18 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_1".to_string()), - ListToolsRequest::new(None).into(), + RequestFromClient::ListToolsRequest(None).into(), ); let mut map = Map::new(); map.insert("name".to_string(), Value::String("Ali".to_string())); let json_rpc_message_2: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_2".to_string()), - CallToolRequest::new(CallToolRequestParams { + RequestFromClient::CallToolRequest(CallToolRequestParams { arguments: Some(map), name: "say_hello".to_string(), + meta: None, + task: None, }) .into(), ); @@ -1171,9 +1201,7 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { result_1.request_id().unwrap(), RequestId::String("req_1".to_string()) ); - let ResultFromServer::ServerResult(ServerResult::ListToolsResult(_)) = - result_1.as_response().unwrap().result - else { + let ResultFromServer::ListToolsResult(_) = result_1.as_response().unwrap().result else { panic!("Expected a ListToolsResult"); }; @@ -1182,9 +1210,7 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { result_2.request_id().unwrap(), RequestId::String("req_2".to_string()) ); - let ResultFromServer::ServerResult(ServerResult::CallToolResult(_)) = - result_2.as_response().unwrap().result - else { + let ResultFromServer::CallToolResult(_) = result_2.as_response().unwrap().result else { panic!("Expected a CallToolResult"); }; } @@ -1452,6 +1478,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { data: json!("notification1"), level: LoggingLevel::Info, logger: None, + meta: None, }, ) .await; @@ -1464,6 +1491,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { data: json!("notification2"), level: LoggingLevel::Info, logger: None, + meta: None, }, ) .await; @@ -1477,10 +1505,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification1, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification1) = message else { panic!("invalid message received!"); }; @@ -1501,10 +1526,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { // deserialize the message returned by event_store let message: ServerJsonrpcNotification = serde_json::from_str(&events.messages[0]).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification2, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification2) = message else { panic!("invalid message in store!"); }; assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); @@ -1526,6 +1548,7 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { data: json!("notification1"), level: LoggingLevel::Info, logger: None, + meta: None, }, ) .await; @@ -1537,10 +1560,7 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification1, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification1) = message else { panic!("invalid message received!"); }; @@ -1561,6 +1581,7 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { data: json!("notification2"), level: LoggingLevel::Info, logger: None, + meta: None, }, ) .await; @@ -1574,10 +1595,7 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { assert_eq!(events.len(), 1); let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); - let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( - notification1, - )) = message.notification - else { + let ServerJsonrpcNotification::LoggingMessageNotification(notification1) = message else { panic!("invalid message received!"); }; @@ -1661,11 +1679,12 @@ async fn should_call_a_tool_with_auth_info_when_authenticated() { let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::Integer(1), - CallToolRequest::new(CallToolRequestParams { + RequestFromClient::CallToolRequest(CallToolRequestParams { arguments: None, name: "display_auth_info".to_string(), - }) - .into(), + meta: None, + task: None, + }), ); let response = send_post_request( @@ -1691,8 +1710,7 @@ async fn should_call_a_tool_with_auth_info_when_authenticated() { assert!(matches!(message.id, RequestId::Integer(1))); - let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result - else { + let ResultFromServer::CallToolResult(result) = message.result else { panic!("invalid CallToolResult") }; @@ -1708,6 +1726,95 @@ async fn should_call_a_tool_with_auth_info_when_authenticated() { .iter() .all(|s| ["mcp", "mcp:tools"].contains(&s.as_str().unwrap())),); } + +#[tokio::test] +async fn should_handle_elicitation() { + #[mcp_elicit(message = "Please enter your info", mode = form)] + #[derive(JsonSchema)] + pub struct UserInfo { + #[json_schema(title = "Name", min_length = 5, max_length = 100)] + pub name: String, + #[json_schema(title = "Email", format = "email")] + pub email: Option, + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + #[json_schema(title = "Tags")] + pub tags: Vec, + } + + common::init_tracing(); + let (server, session_id) = initialize_server(Some(false), None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let mut content: HashMap = HashMap::new(); + content.insert("name".to_string(), "Alice".into()); + content.insert("email".to_string(), "alice@Borderland.com".into()); + content.insert("age".to_string(), 25.into()); + content.insert("tags".to_string(), vec!["rust", "c++"].into()); + + let elicit_response: ElicitResult = ElicitResult { + action: rust_mcp_schema::ElicitResultAction::Accept, + content: Some(content), + meta: None, + }; + + let elicit_rpc_response = ClientJsonrpcResponse::new( + RequestId::Integer(0), + ResultFromClient::ElicitResult(elicit_response.clone()), + ); + + // send a response bac after 500ms delay, to simulate client response + let sid_clone = session_id.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(521)).await; + let _response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&elicit_rpc_response).unwrap(), + Some(&sid_clone), + None, + ) + .await + .expect("Request failed"); + }); + + server + .hyper_runtime + .request_elicitation(&session_id, UserInfo::elicit_request_params()) + .await + .unwrap(); + + let message_str = read_sse_event(response, 1) + .await + .unwrap() + .first() + .unwrap() + .2 + .clone(); + + let message: ServerRequest = serde_json::from_str(&message_str).unwrap(); + assert!(matches!(message, ServerRequest::ElicitRequest(_))); + + // create UserInfo from ElicitResult content + let user: UserInfo = UserInfo::from_elicit_result_content(elicit_response.content).unwrap(); + + assert_eq!(user.name, "Alice"); + assert_eq!(user.age, 25); + assert_eq!(user.tags, vec!["rust", "c++"]); + assert_eq!(user.email.as_ref().unwrap(), "alice@Borderland.com"); + + println!("name: {}", user.name); + println!("age: {}", user.age); + println!( + "email: {}", + user.email.clone().unwrap_or("not provider".into()) + ); + println!("tags: {}", user.tags.join(",")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap(); +} + // should return 400 error for invalid JSON-RPC messages // should keep stream open after sending server notifications // NA: should reject second initialization request diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 753f644..e19499e 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -14,7 +14,7 @@ exclude = ["assets/"] [dependencies] -rust-mcp-schema = { workspace = true, default-features = false } +rust-mcp-schema = { workspace = true, default-features = false , features=["latest","schema_utils"]} async-trait = { workspace = true } futures = { workspace = true } @@ -43,18 +43,8 @@ workspace = true ### FEATURES ################################################################# [features] -default = ["stdio", "sse", "streamable-http", "2025_06_18"] # Default features +default = ["stdio", "sse", "streamable-http"] # Default features stdio = [] sse = ["reqwest"] streamable-http = ["reqwest"] - - -# enabled mcp protocol version 2025_06_18 -2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] - -# enabled mcp protocol version 2025_03_26 -2025_03_26 = ["rust-mcp-schema/2025_03_26", "rust-mcp-schema/schema_utils"] - -# enabled mcp protocol version 2024_11_05 -2024_11_05 = ["rust-mcp-schema/2024_11_05", "rust-mcp-schema/schema_utils"] diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 0b10918..297b73b 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -119,6 +119,7 @@ impl MCPStream { line = lines_stream.next_line() =>{ match line { Ok(Some(line)) => { + tracing::debug!("raw payload: {}",line); // deserialize and send it to the stream let message: X = match serde_json::from_str(&line){ @@ -128,6 +129,7 @@ impl MCPStream { continue; }, }; + tx.send(message).await.map_err(GenericSendError::new)?; } Ok(None) => { diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 62c591f..e42c63f 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -9,7 +9,7 @@ use crate::{ self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, }, - JsonrpcError, + JsonrpcErrorResponse, }, SessionId, StreamId, }; @@ -192,7 +192,7 @@ impl McpDispatch .filter(|message| message.is_request()) .map(|message| { ( - message.request_id().unwrap(), // guaranteed to have request_id + message.request_id(), self.store_pending_request_for_message(message), ) }) @@ -222,9 +222,9 @@ impl McpDispatch .zip(request_ids) .map(|(res, request_id)| match res { Ok(response) => response, - Err(error) => ServerMessage::Error(JsonrpcError::new( + Err(error) => ServerMessage::Error(JsonrpcErrorResponse::new( RpcError::internal_error().with_message(error.to_string()), - request_id.to_owned(), + request_id.cloned(), )), }) .collect(); @@ -334,7 +334,7 @@ impl McpDispatch .filter(|message| message.is_request()) .map(|message| { ( - message.request_id().unwrap(), // guaranteed to have request_id + message.request_id(), self.store_pending_request_for_message(message), ) }) @@ -364,9 +364,9 @@ impl McpDispatch .zip(request_ids) .map(|(res, request_id)| match res { Ok(response) => response, - Err(error) => ClientMessage::Error(JsonrpcError::new( + Err(error) => ClientMessage::Error(JsonrpcErrorResponse::new( RpcError::internal_error().with_message(error.to_string()), - request_id.to_owned(), + request_id.cloned(), )), }) .collect(); diff --git a/crates/rust-mcp-transport/src/schema.rs b/crates/rust-mcp-transport/src/schema.rs index 2c7e7b4..956ad46 100644 --- a/crates/rust-mcp-transport/src/schema.rs +++ b/crates/rust-mcp-transport/src/schema.rs @@ -1,14 +1 @@ -#[cfg(feature = "2025_06_18")] pub use rust_mcp_schema::*; - -#[cfg(all( - feature = "2025_03_26", - not(any(feature = "2024_11_05", feature = "2025_06_18")) -))] -pub use rust_mcp_schema::mcp_2025_03_26::*; - -#[cfg(all( - feature = "2024_11_05", - not(any(feature = "2025_03_26", feature = "2025_06_18")) -))] -pub use rust_mcp_schema::mcp_2024_11_05::*; diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 418fd66..9fd7afa 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -175,11 +175,7 @@ impl ServerHandler for MyServerHandler { } //Handles incoming CallToolRequest and processes it using the appropriate tool. - async fn handle_call_tool_request( - &self, - request: CallToolRequest, - _runtime: Arc, - ) -> std::result::Result { + async fn handle_call_tool_request(&self, request: CallToolRequest) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = GreetingTools::try_from(request.params).map_err(CallToolError::new)?; @@ -216,7 +212,9 @@ use rust_mcp_sdk::schema::{ ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - McpServer, StdioTransport, TransportOptions, error::SdkResult, mcp_server::server_runtime, + error::SdkResult, + mcp_server::{server_runtime, ServerRuntime, ToMcpServerHandler}, + McpServer, StdioTransport, TransportOptions, }; #[tokio::main] @@ -246,7 +244,7 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; //create the MCP server - let server = server_runtime::create_server(server_details, transport, handler); + let server = server_runtime::create_server(server_details, transport, handler.to_mcp_server_handler()); // Start the server server.start().await diff --git a/examples/auth/server-oauth-remote/Cargo.toml b/examples/auth/server-oauth-remote/Cargo.toml index baf5bc2..1e1b933 100644 --- a/examples/auth/server-oauth-remote/Cargo.toml +++ b/examples/auth/server-oauth-remote/Cargo.toml @@ -14,7 +14,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "sse", "auth", "hyper-server", - "2025_06_18", ] } rust-mcp-extra={ workspace = true, features=["auth"]} diff --git a/examples/auth/server-oauth-remote/src/handler.rs b/examples/auth/server-oauth-remote/src/handler.rs index d137080..1dd7638 100644 --- a/examples/auth/server-oauth-remote/src/handler.rs +++ b/examples/auth/server-oauth-remote/src/handler.rs @@ -1,11 +1,10 @@ use async_trait::async_trait; use rust_mcp_sdk::auth::AuthInfo; use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; -use rust_mcp_sdk::schema::TextContent; use rust_mcp_sdk::schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, - ListToolsResult, RpcError, + schema_utils::CallToolError, CallToolResult, ListToolsResult, RpcError, }; +use rust_mcp_sdk::schema::{CallToolRequestParams, PaginatedRequestParams, TextContent}; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; use std::sync::Arc; use std::vec; @@ -42,7 +41,7 @@ impl ServerHandler for McpServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult async fn handle_list_tools_request( &self, - request: ListToolsRequest, + params: Option, runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { @@ -55,16 +54,16 @@ impl ServerHandler for McpServerHandler { /// Handles incoming CallToolRequest and processes it using the appropriate tool. async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { - if request.params.name.eq(&ShowAuthInfo::tool_name()) { + if params.name.eq(&ShowAuthInfo::tool_name()) { let tool = ShowAuthInfo::default(); tool.call_tool(runtime.auth_info_cloned().await) } else { Err(CallToolError::from_message(format!( "Tool \"{}\" does not exists or inactive!", - request.params.name, + params.name, ))) } } diff --git a/examples/auth/server-oauth-remote/src/main.rs b/examples/auth/server-oauth-remote/src/main.rs index e1d442e..ed1d047 100644 --- a/examples/auth/server-oauth-remote/src/main.rs +++ b/examples/auth/server-oauth-remote/src/main.rs @@ -12,6 +12,7 @@ use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, }; +use rust_mcp_sdk::{mcp_icon, ToMcpServerHandler}; use std::env; use std::sync::Arc; use std::time::Duration; @@ -94,9 +95,17 @@ async fn main() -> SdkResult<()> { let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Remote Oauth Test MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Remote Oauth Test MCP Server".to_string()), + name: "Remote Oauth Test MCP Server".into(), + version: "0.1.0".into(), + title: Some("Remote Oauth Test MCP Server".into()), + description: Some("Remote Oauth Test MCP Server, by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -104,8 +113,8 @@ async fn main() -> SdkResult<()> { ..Default::default() // Using default values for other fields }, meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + instructions: Some("server instructions...".into()), + protocol_version: LATEST_PROTOCOL_VERSION.into(), }; let handler = McpServerHandler {}; @@ -114,11 +123,11 @@ async fn main() -> SdkResult<()> { let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { - host: "localhost".to_string(), + host: "localhost".into(), port: 3000, - custom_streamable_http_endpoint: Some("/".to_string()), + custom_streamable_http_endpoint: Some("/".into()), ping_interval: Duration::from_secs(5), event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability auth: Some(Arc::new(oauth_metadata_provider)), // enable authentication diff --git a/examples/hello-world-mcp-server-stdio-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml index 5667e11..ceeaacf 100644 --- a/examples/hello-world-mcp-server-stdio-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", "stdio", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/hello-world-mcp-server-stdio-core/src/handler.rs b/examples/hello-world-mcp-server-stdio-core/src/handler.rs index acf55ea..039452b 100644 --- a/examples/hello-world-mcp-server-stdio-core/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio-core/src/handler.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::{CallToolError, NotificationFromClient, RequestFromClient, ResultFromServer}, - ClientRequest, ListToolsResult, RpcError, + ListToolsResult, RpcError, }; use rust_mcp_sdk::{ mcp_server::{enforce_compatible_protocol_version, ServerHandlerCore}, @@ -28,60 +28,53 @@ impl ServerHandlerCore for MyServerHandler { ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { - //Handle client requests according to their specific type. - RequestFromClient::ClientRequest(client_request) => match client_request { - // Handle the initialization request - ClientRequest::InitializeRequest(initialize_request) => { - let mut server_info = runtime.server_info().to_owned(); + // Handle the initialization request + RequestFromClient::InitializeRequest(params) => { + let mut server_info = runtime.server_info().to_owned(); - if let Some(updated_protocol_version) = enforce_compatible_protocol_version( - &initialize_request.params.protocol_version, - &server_info.protocol_version, - ) - .map_err(|err| RpcError::internal_error().with_message(err.to_string()))? - { - server_info.protocol_version = initialize_request.params.protocol_version; - } - - return Ok(server_info.into()); + if let Some(updated_protocol_version) = enforce_compatible_protocol_version( + ¶ms.protocol_version, + &server_info.protocol_version, + ) + .map_err(|err| RpcError::internal_error().with_message(err.to_string()))? + { + server_info.protocol_version = params.protocol_version; } - // Handle ListToolsRequest, return list of available tools - ClientRequest::ListToolsRequest(_) => Ok(ListToolsResult { - meta: None, - next_cursor: None, - tools: GreetingTools::tools(), - } - .into()), + return Ok(server_info.into()); + } - // Handles incoming CallToolRequest and processes it using the appropriate tool. - ClientRequest::CallToolRequest(request) => { - let tool_name = request.tool_name().to_string(); + // Handle ListToolsRequest, return list of available tools + RequestFromClient::ListToolsRequest(_params) => Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: GreetingTools::tools(), + } + .into()), - // Attempt to convert request parameters into GreetingTools enum - let tool_params = GreetingTools::try_from(request.params) - .map_err(|_| CallToolError::unknown_tool(tool_name.clone()))?; + // Handles incoming CallToolRequest and processes it using the appropriate tool. + RequestFromClient::CallToolRequest(params) => { + let tool_name = params.name.to_string(); - // Match the tool variant and execute its corresponding logic - let result = match tool_params { - GreetingTools::SayHelloTool(say_hello_tool) => { - say_hello_tool.call_tool().map_err(|err| { - RpcError::internal_error().with_message(err.to_string()) - })? - } - GreetingTools::SayGoodbyeTool(say_goodbye_tool) => { - say_goodbye_tool.call_tool().map_err(|err| { - RpcError::internal_error().with_message(err.to_string()) - })? - } - }; - Ok(result.into()) - } + // Attempt to convert request parameters into GreetingTools enum + let tool_params = GreetingTools::try_from(params) + .map_err(|_| CallToolError::unknown_tool(tool_name.clone()))?; + + // Match the tool variant and execute its corresponding logic + let result = match tool_params { + GreetingTools::SayHelloTool(say_hello_tool) => say_hello_tool + .call_tool() + .map_err(|err| RpcError::internal_error().with_message(err.to_string()))?, + GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool + .call_tool() + .map_err(|err| RpcError::internal_error().with_message(err.to_string()))?, + }; + Ok(result.into()) + } - // Return Method not found for any other requests - _ => Err(RpcError::method_not_found() - .with_message(format!("No handler is implemented for '{method_name}'.",))), - }, + // Return Method not found for any other requests + _ => Err(RpcError::method_not_found() + .with_message(format!("No handler is implemented for '{method_name}'.",))), // Handle custom requests RequestFromClient::CustomRequest(_) => Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())), diff --git a/examples/hello-world-mcp-server-stdio-core/src/main.rs b/examples/hello-world-mcp-server-stdio-core/src/main.rs index d410526..e2ca3b6 100644 --- a/examples/hello-world-mcp-server-stdio-core/src/main.rs +++ b/examples/hello-world-mcp-server-stdio-core/src/main.rs @@ -2,12 +2,16 @@ mod handler; mod tools; use handler::MyServerHandler; +use rust_mcp_sdk::mcp_icon; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, }; + use rust_mcp_sdk::{ - error::SdkResult, mcp_server::server_runtime_core, McpServer, StdioTransport, TransportOptions, + error::SdkResult, + mcp_server::{server_runtime_core, ToMcpServerHandlerCore}, + McpServer, StdioTransport, TransportOptions, }; #[tokio::main] @@ -16,9 +20,17 @@ async fn main() -> SdkResult<()> { let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), + name: "Hello World MCP Server".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("Hello World MCP Server, by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -26,8 +38,8 @@ async fn main() -> SdkResult<()> { ..Default::default() // Using default values for other fields }, meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + instructions: Some("server instructions...".into()), + protocol_version: LATEST_PROTOCOL_VERSION.into(), }; // STEP 2: create a std transport with default options @@ -38,7 +50,11 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server = server_runtime_core::create_server(server_details, transport, handler); + let server = server_runtime_core::create_server( + server_details, + transport, + handler.to_mcp_server_handler(), + ); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-mcp-server-stdio/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml index 3c35af3..8ac01af 100644 --- a/examples/hello-world-mcp-server-stdio/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", "stdio", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/hello-world-mcp-server-stdio/src/handler.rs b/examples/hello-world-mcp-server-stdio/src/handler.rs index 47925a0..e0b2944 100644 --- a/examples/hello-world-mcp-server-stdio/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio/src/handler.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, - ListToolsResult, RpcError, + schema_utils::CallToolError, CallToolResult, ListToolsResult, RpcError, }; +use rust_mcp_sdk::schema::{CallToolRequestParams, PaginatedRequestParams}; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; use std::sync::Arc; @@ -20,7 +20,7 @@ impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult async fn handle_list_tools_request( &self, - request: ListToolsRequest, + params: Option, runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { @@ -33,12 +33,12 @@ impl ServerHandler for MyServerHandler { /// Handles incoming CallToolRequest and processes it using the appropriate tool. async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = - GreetingTools::try_from(request.params).map_err(CallToolError::new)?; + GreetingTools::try_from(params).map_err(CallToolError::new)?; // Match the tool variant and execute its corresponding logic match tool_params { diff --git a/examples/hello-world-mcp-server-stdio/src/main.rs b/examples/hello-world-mcp-server-stdio/src/main.rs index 9e5d2b3..6bac2be 100644 --- a/examples/hello-world-mcp-server-stdio/src/main.rs +++ b/examples/hello-world-mcp-server-stdio/src/main.rs @@ -1,36 +1,56 @@ mod handler; mod tools; - use handler::MyServerHandler; +use rust_mcp_sdk::mcp_icon; use rust_mcp_sdk::schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, - LATEST_PROTOCOL_VERSION, + Implementation, InitializeResult, ProtocolVersion, ServerCapabilities, + ServerCapabilitiesResources, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ error::SdkResult, - mcp_server::{server_runtime, ServerRuntime}, + mcp_server::{server_runtime, ServerRuntime, ToMcpServerHandler}, McpServer, StdioTransport, TransportOptions, }; use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; #[tokio::main] async fn main() -> SdkResult<()> { + // initialize tracing + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); // STEP 1: Define server details and capabilities let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Hello World MCP Server".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server".to_string()), + name: "Hello World MCP Server".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server".into()), + description: Some("Hello World MCP Server, by Rust MCP SDK".into()), + icons: vec![ + mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + ) + ], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools tools: Some(ServerCapabilitiesTools { list_changed: None }), + resources: Some(ServerCapabilitiesResources { list_changed: None, subscribe: None }), ..Default::default() // Using default values for other fields }, meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + instructions: Some("server instructions...".into()), + protocol_version: ProtocolVersion::V2025_11_25.into(), }; // STEP 2: create a std transport with default options @@ -41,7 +61,7 @@ async fn main() -> SdkResult<()> { // STEP 4: create a MCP server let server: Arc = - server_runtime::create_server(server_details, transport, handler); + server_runtime::create_server(server_details, transport, handler.to_mcp_server_handler()); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-mcp-server-stdio/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs index dc14645..8dca93e 100644 --- a/examples/hello-world-mcp-server-stdio/src/tools.rs +++ b/examples/hello-world-mcp-server-stdio/src/tools.rs @@ -1,28 +1,6 @@ use rust_mcp_sdk::macros::JsonSchema; use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult, TextContent}; use rust_mcp_sdk::{macros::mcp_tool, tool_box}; -// use rust_mcp_sdk::schema::RpcError; -// use std::str::FromStr; - -// // Simple enum with FromStr trait implemented -// #[derive(JsonSchema, Debug)] -// pub enum Colors { -// #[json_schema(title = "Green Color")] -// Green, -// #[json_schema(title = "Red Color")] -// Red, -// } -// impl FromStr for Colors { -// type Err = RpcError; - -// fn from_str(s: &str) -> Result { -// match s.to_lowercase().as_str() { -// "green" => Ok(Colors::Green), -// "red" => Ok(Colors::Red), -// _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), -// } -// } -// } //****************// // SayHelloTool // diff --git a/examples/hello-world-server-streamable-http-core/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml index e2a22a7..f76173b 100644 --- a/examples/hello-world-server-streamable-http-core/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -13,7 +13,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "streamable-http", "sse", "hyper-server", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/hello-world-server-streamable-http-core/README.md b/examples/hello-world-server-streamable-http-core/README.md index 49af2c2..cbfc294 100644 --- a/examples/hello-world-server-streamable-http-core/README.md +++ b/examples/hello-world-server-streamable-http-core/README.md @@ -17,7 +17,8 @@ To disable the SSE transport, set the `sse_support` value in the `HyperServerOpt ```rs let server = - hyper_server_core::create_server(server_details, handler, + hyper_server_core::create_server(server_details, + handler.to_mcp_server_handler(), HyperServerOptions{ sse_support: false, // Disable SSE support Default::default() diff --git a/examples/hello-world-server-streamable-http-core/src/handler.rs b/examples/hello-world-server-streamable-http-core/src/handler.rs index 7941075..7795846 100644 --- a/examples/hello-world-server-streamable-http-core/src/handler.rs +++ b/examples/hello-world-server-streamable-http-core/src/handler.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::{CallToolError, NotificationFromClient, RequestFromClient, ResultFromServer}, - ClientRequest, ListToolsResult, RpcError, + ListToolsResult, RpcError, }; use rust_mcp_sdk::{ mcp_server::{enforce_compatible_protocol_version, ServerHandlerCore}, @@ -28,65 +28,59 @@ impl ServerHandlerCore for MyServerHandler { ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { - //Handle client requests according to their specific type. - RequestFromClient::ClientRequest(client_request) => match client_request { - // Handle the initialization request - ClientRequest::InitializeRequest(initialize_request) => { - let mut server_info = runtime.server_info().to_owned(); + // Handle the initialization request + RequestFromClient::InitializeRequest(params) => { + let mut server_info = runtime.server_info().to_owned(); - if let Some(updated_protocol_version) = enforce_compatible_protocol_version( - &initialize_request.params.protocol_version, - &server_info.protocol_version, - ) - .map_err(|err| { - tracing::error!( - "Incompatible protocol version :\nclient: {}\nserver: {}", - &initialize_request.params.protocol_version, - &server_info.protocol_version - ); - RpcError::internal_error().with_message(err.to_string()) - })? { - server_info.protocol_version = updated_protocol_version; - } - - return Ok(server_info.into()); - } - // Handle ListToolsRequest, return list of available tools - ClientRequest::ListToolsRequest(_) => Ok(ListToolsResult { - meta: None, - next_cursor: None, - tools: GreetingTools::tools(), + if let Some(updated_protocol_version) = enforce_compatible_protocol_version( + ¶ms.protocol_version, + &server_info.protocol_version, + ) + .map_err(|err| { + tracing::error!( + "Incompatible protocol version :\nclient: {}\nserver: {}", + ¶ms.protocol_version, + &server_info.protocol_version + ); + RpcError::internal_error().with_message(err.to_string()) + })? { + server_info.protocol_version = updated_protocol_version; } - .into()), - // Handles incoming CallToolRequest and processes it using the appropriate tool. - ClientRequest::CallToolRequest(request) => { - let tool_name = request.tool_name().to_string(); + return Ok(server_info.into()); + } + // Handle ListToolsRequest, return list of available tools + RequestFromClient::ListToolsRequest(_) => Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: GreetingTools::tools(), + } + .into()), - // Attempt to convert request parameters into GreetingTools enum - let tool_params = GreetingTools::try_from(request.params) - .map_err(|_| CallToolError::unknown_tool(tool_name.clone()))?; + // Handles incoming CallToolRequest and processes it using the appropriate tool. + RequestFromClient::CallToolRequest(params) => { + let tool_name = params.name.to_string(); - // Match the tool variant and execute its corresponding logic - let result = match tool_params { - GreetingTools::SayHelloTool(say_hello_tool) => { - say_hello_tool.call_tool().map_err(|err| { - RpcError::internal_error().with_message(err.to_string()) - })? - } - GreetingTools::SayGoodbyeTool(say_goodbye_tool) => { - say_goodbye_tool.call_tool().map_err(|err| { - RpcError::internal_error().with_message(err.to_string()) - })? - } - }; - Ok(result.into()) - } + // Attempt to convert request parameters into GreetingTools enum + let tool_params = GreetingTools::try_from(params) + .map_err(|_| CallToolError::unknown_tool(tool_name.clone()))?; + + // Match the tool variant and execute its corresponding logic + let result = match tool_params { + GreetingTools::SayHelloTool(say_hello_tool) => say_hello_tool + .call_tool() + .map_err(|err| RpcError::internal_error().with_message(err.to_string()))?, + GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool + .call_tool() + .map_err(|err| RpcError::internal_error().with_message(err.to_string()))?, + }; + Ok(result.into()) + } + + // Return Method not found for any other requests + _ => Err(RpcError::method_not_found() + .with_message(format!("No handler is implemented for '{method_name}'.",))), - // Return Method not found for any other requests - _ => Err(RpcError::method_not_found() - .with_message(format!("No handler is implemented for '{method_name}'.",))), - }, // Handle custom requests RequestFromClient::CustomRequest(_) => Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())), diff --git a/examples/hello-world-server-streamable-http-core/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs index 81a6ae5..2adff69 100644 --- a/examples/hello-world-server-streamable-http-core/src/main.rs +++ b/examples/hello-world-server-streamable-http-core/src/main.rs @@ -5,13 +5,14 @@ use std::sync::Arc; use handler::MyServerHandler; use rust_mcp_sdk::event_store::InMemoryEventStore; +use rust_mcp_sdk::mcp_icon; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, }; use rust_mcp_sdk::{ error::SdkResult, - mcp_server::{hyper_server_core, HyperServerOptions}, + mcp_server::{hyper_server_core, HyperServerOptions, ToMcpServerHandlerCore}, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -28,9 +29,19 @@ async fn main() -> SdkResult<()> { let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Hello World MCP Server Streamable HTTP + SSE".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server Streamable HTTP + SSE".to_string()), + name: "Hello World MCP Server Streamable HTTP + SSE".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server Streamable HTTP + SSE".into()), + description: Some( + "Hello World MCP Server Streamable HTTP + SSE, by Rust MCP SDK".into(), + ), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -38,8 +49,8 @@ async fn main() -> SdkResult<()> { ..Default::default() // Using default values for other fields }, meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + instructions: Some("server instructions...".into()), + protocol_version: LATEST_PROTOCOL_VERSION.into(), }; // STEP 2: instantiate our custom handler for handling MCP messages @@ -48,7 +59,7 @@ async fn main() -> SdkResult<()> { // STEP 3: create a MCP server let server = hyper_server_core::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { sse_support: true, event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 0fe7962..aae110f 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -13,7 +13,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "streamable-http", "sse", "hyper-server", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md index 7e3f3b6..aa20251 100644 --- a/examples/hello-world-server-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http/README.md @@ -16,7 +16,8 @@ To disable the SSE transport, set the `sse_support` value in the `HyperServerOpt ```rs let server = - hyper_server_core::create_server(server_details, handler, + hyper_server_core::create_server(server_details, + handler.to_mcp_server_handler(), HyperServerOptions{ sse_support: false, // Disable SSE support Default::default() diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index 3939d86..5416a6d 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,9 +1,9 @@ use crate::tools::GreetingTools; use async_trait::async_trait; use rust_mcp_sdk::schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, - ListToolsResult, RpcError, + schema_utils::CallToolError, CallToolResult, ListToolsResult, RpcError, }; +use rust_mcp_sdk::schema::{CallToolRequestParams, PaginatedRequestParams}; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; use std::sync::Arc; // Custom Handler to handle MCP Messages @@ -18,7 +18,7 @@ impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult async fn handle_list_tools_request( &self, - request: ListToolsRequest, + params: Option, runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { @@ -31,12 +31,12 @@ impl ServerHandler for MyServerHandler { /// Handles incoming CallToolRequest and processes it using the appropriate tool. async fn handle_call_tool_request( &self, - request: CallToolRequest, + params: CallToolRequestParams, runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = - GreetingTools::try_from(request.params).map_err(CallToolError::new)?; + GreetingTools::try_from(params).map_err(CallToolError::new)?; // Match the tool variant and execute its corresponding logic match tool_params { diff --git a/examples/hello-world-server-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs index c4fd373..178a109 100644 --- a/examples/hello-world-server-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -3,7 +3,8 @@ mod tools; use handler::MyServerHandler; use rust_mcp_sdk::event_store::InMemoryEventStore; -use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; +use rust_mcp_sdk::mcp_icon; +use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions, ToMcpServerHandler}; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, @@ -32,9 +33,17 @@ async fn main() -> SdkResult<()> { let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Hello World MCP Server SSE".to_string(), - version: "0.1.0".to_string(), - title: Some("Hello World MCP Server SSE".to_string()), + name: "Hello World MCP Server SSE".into(), + version: "0.1.0".into(), + title: Some("Hello World MCP Server SSE".into()), + description: Some("test server, by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -42,8 +51,8 @@ async fn main() -> SdkResult<()> { ..Default::default() // Using default values for other fields }, meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + instructions: Some("server instructions...".into()), + protocol_version: LATEST_PROTOCOL_VERSION.into(), }; // STEP 2: instantiate our custom handler for handling MCP messages @@ -52,9 +61,9 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate HyperServer, providing `server_details` , `handler` and HyperServerOptions let server = hyper_server::create_server( server_details, - handler, + handler.to_mcp_server_handler(), HyperServerOptions { - host: "127.0.0.1".to_string(), + host: "127.0.0.1".into(), ping_interval: Duration::from_secs(5), event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() diff --git a/examples/simple-mcp-client-sse-core/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml index 0c185b7..0e01a76 100644 --- a/examples/simple-mcp-client-sse-core/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", "sse", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/simple-mcp-client-sse-core/src/handler.rs b/examples/simple-mcp-client-sse-core/src/handler.rs index ab86e9e..1fb5af7 100644 --- a/examples/simple-mcp-client-sse-core/src/handler.rs +++ b/examples/simple-mcp-client-sse-core/src/handler.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ self, schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - RpcError, ServerRequest, + RpcError, }; use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; pub struct MyClientHandler; @@ -18,23 +18,24 @@ impl ClientHandlerCore for MyClientHandler { _runtime: &dyn McpClient, ) -> std::result::Result { match request { - RequestFromServer::ServerRequest(server_request) => match server_request { - ServerRequest::PingRequest(_) => { - return Ok(schema::Result::default().into()); - } - ServerRequest::CreateMessageRequest(_create_message_request) => { - Err(RpcError::internal_error().with_message( - "CreateMessageRequest handler is not implemented".to_string(), - )) - } - ServerRequest::ListRootsRequest(_list_roots_request) => { - Err(RpcError::internal_error() - .with_message("ListRootsRequest handler is not implemented".to_string())) - } - ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() - .with_message("ElicitRequest handler is not implemented".to_string())), - }, - RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + RequestFromServer::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + RequestFromServer::CreateMessageRequest(_) => Err(RpcError::internal_error() + .with_message("CreateMessageRequest handler is not implemented".to_string())), + RequestFromServer::ListRootsRequest(_) => Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())), + RequestFromServer::ElicitRequest(_) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskRequest(_) => Err(RpcError::internal_error() + .with_message("GetTaskRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskPayloadRequest(_) => Err(RpcError::internal_error() + .with_message("GetTaskPayloadRequest handler is not implemented".to_string())), + RequestFromServer::CancelTaskRequest(_) => Err(RpcError::internal_error() + .with_message("CancelTaskRequest handler is not implemented".to_string())), + RequestFromServer::ListTasksRequest(_) => Err(RpcError::internal_error() + .with_message("ListTasksRequest handler is not implemented".to_string())), + RequestFromServer::CustomRequest(_) => Err(RpcError::internal_error() .with_message("CustomRequest handler is not implemented".to_string())), } } @@ -44,21 +45,7 @@ impl ClientHandlerCore for MyClientHandler { notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - if let NotificationFromServer::ServerNotification( - schema::ServerNotification::LoggingMessageNotification(logging_message_notification), - ) = notification - { - println!( - "Notification from server: {}", - logging_message_notification.params.data - ); - } else { - println!( - "A {} notification received from the server", - notification.method() - ); - }; - + println!("Notification from server: \"{}\"", notification.method()); Ok(()) } diff --git a/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-sse-core/src/main.rs b/examples/simple-mcp-client-sse-core/src/main.rs index be8279b..80fdce7 100644 --- a/examples/simple-mcp-client-sse-core/src/main.rs +++ b/examples/simple-mcp-client-sse-core/src/main.rs @@ -8,9 +8,9 @@ use rust_mcp_sdk::error::SdkResult; use rust_mcp_sdk::mcp_client::client_runtime_core; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; -use rust_mcp_sdk::{ClientSseTransport, ClientSseTransportOptions, McpClient}; +use rust_mcp_sdk::{mcp_icon, ClientSseTransport, ClientSseTransportOptions, McpClient}; use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -30,11 +30,20 @@ async fn main() -> SdkResult<()> { let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client-core-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + name: "simple-rust-mcp-client-core-sse".into(), + version: "0.1.0".into(), + title: Some("Simple Rust MCP Client (Core,SSE)".into()), + description: Some("Simple Rust MCP Client (Core,SSE), by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step2 : Create a transport, with options to launch/connect to a MCP Server @@ -80,7 +89,13 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await?; // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 1e154cc..a68fc22 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -12,7 +12,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "sse", "streamable-http", "macros", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/simple-mcp-client-sse/src/inquiry_utils.rs b/examples/simple-mcp-client-sse/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-sse/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-sse/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index 0a76caa..e9833d3 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -8,9 +8,9 @@ use rust_mcp_sdk::error::SdkResult; use rust_mcp_sdk::mcp_client::client_runtime; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; -use rust_mcp_sdk::{ClientSseTransport, ClientSseTransportOptions, McpClient}; +use rust_mcp_sdk::{mcp_icon, ClientSseTransport, ClientSseTransportOptions, McpClient}; use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -32,11 +32,20 @@ async fn main() -> SdkResult<()> { let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (SSE)".to_string()), + name: "simple-rust-mcp-client-sse".into(), + version: "0.1.0".into(), + title: Some("Simple Rust MCP Client (SSE)".into()), + description: Some("Simple Rust MCP Client (SSE) by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step2 : Create a transport, with options to launch/connect to a MCP Server @@ -83,7 +92,14 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // // Set the log level - match utils.client.set_logging_level(LoggingLevel::Debug).await { + match utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await + { Ok(_) => println!("Log level is set to \"Debug\""), Err(err) => eprintln!("Error setting the Log level : {err}"), } diff --git a/examples/simple-mcp-client-stdio-core/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml index 4144cae..b124b9d 100644 --- a/examples/simple-mcp-client-stdio-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", "stdio", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/simple-mcp-client-stdio-core/src/handler.rs b/examples/simple-mcp-client-stdio-core/src/handler.rs index bd5e4fe..0c325ba 100644 --- a/examples/simple-mcp-client-stdio-core/src/handler.rs +++ b/examples/simple-mcp-client-stdio-core/src/handler.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ self, schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - RpcError, ServerRequest, + RpcError, }; use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; pub struct MyClientHandler; @@ -18,24 +18,29 @@ impl ClientHandlerCore for MyClientHandler { _runtime: &dyn McpClient, ) -> std::result::Result { match request { - RequestFromServer::ServerRequest(server_request) => match server_request { - ServerRequest::PingRequest(_) => { - return Ok(schema::Result::default().into()); - } - ServerRequest::CreateMessageRequest(_create_message_request) => { - Err(RpcError::internal_error().with_message( - "CreateMessageRequest handler is not implemented".to_string(), - )) - } - ServerRequest::ListRootsRequest(_list_roots_request) => { - Err(RpcError::internal_error() - .with_message("ListRootsRequest handler is not implemented".to_string())) - } - ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() - .with_message("ElicitRequest handler is not implemented".to_string())), - }, - RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + RequestFromServer::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + RequestFromServer::CreateMessageRequest(_create_message_request) => { + Err(RpcError::internal_error() + .with_message("CreateMessageRequest handler is not implemented".to_string())) + } + RequestFromServer::ListRootsRequest(_list_roots_request) => { + Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())) + } + RequestFromServer::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + RequestFromServer::CustomRequest(_request) => Err(RpcError::internal_error() .with_message("CustomRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskRequest(_request) => Err(RpcError::internal_error() + .with_message("GetTaskRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskPayloadRequest(_request) => Err(RpcError::internal_error() + .with_message("GetTaskPayloadRequest handler is not implemented".to_string())), + RequestFromServer::CancelTaskRequest(_request) => Err(RpcError::internal_error() + .with_message("CancelTaskRequest handler is not implemented".to_string())), + RequestFromServer::ListTasksRequest(_request) => Err(RpcError::internal_error() + .with_message("ListTasksRequest handler is not implemented".to_string())), } } diff --git a/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-stdio-core/src/main.rs b/examples/simple-mcp-client-stdio-core/src/main.rs index c129239..7dcdeac 100644 --- a/examples/simple-mcp-client-stdio-core/src/main.rs +++ b/examples/simple-mcp-client-stdio-core/src/main.rs @@ -5,10 +5,10 @@ use handler::MyClientHandler; use inquiry_utils::InquiryUtils; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; use rust_mcp_sdk::{error::SdkResult, mcp_client::client_runtime_core}; -use rust_mcp_sdk::{McpClient, StdioTransport, TransportOptions}; +use rust_mcp_sdk::{mcp_icon, McpClient, StdioTransport, TransportOptions}; use std::sync::Arc; const MCP_SERVER_TO_LAUNCH: &str = "@modelcontextprotocol/server-everything"; @@ -19,11 +19,20 @@ async fn main() -> SdkResult<()> { let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client-core".to_string(), + name: "simple-rust-mcp-client-core".into(), version: "0.1.0".into(), - title: Some("Simple Rust MCP Client Core".to_string()), + title: Some("Simple Rust MCP Client Core".into()), + description: Some("Hello World MCP Server, by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step2 : Create a transport, with options to launch/connect to a MCP Server @@ -74,7 +83,13 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await?; // Send 3 ping requests to the server, with a 2-second interval between each ping request. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-stdio/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml index 496efa5..dc2c9bd 100644 --- a/examples/simple-mcp-client-stdio/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", "stdio", - "2025_06_18", ] } tokio = { workspace = true } @@ -21,6 +20,8 @@ async-trait = { workspace = true } futures = { workspace = true } thiserror = { workspace = true } colored = "3.0.0" +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } [lints] workspace = true diff --git a/examples/simple-mcp-client-stdio/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-stdio/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-stdio/src/main.rs b/examples/simple-mcp-client-stdio/src/main.rs index e2bb3ab..9bf8673 100644 --- a/examples/simple-mcp-client-stdio/src/main.rs +++ b/examples/simple-mcp-client-stdio/src/main.rs @@ -7,24 +7,42 @@ use rust_mcp_sdk::error::SdkResult; use rust_mcp_sdk::mcp_client::client_runtime; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; -use rust_mcp_sdk::{McpClient, StdioTransport, TransportOptions}; +use rust_mcp_sdk::{mcp_icon, McpClient, StdioTransport, TransportOptions}; use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; const MCP_SERVER_TO_LAUNCH: &str = "@modelcontextprotocol/server-everything"; #[tokio::main] async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + // Step1 : Define client details and capabilities let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client".to_string()), + name: "simple-rust-mcp-client".into(), + version: "0.1.0".into(), + title: Some("Simple Rust MCP Client".into()), + description: Some("Simple Rust MCP Client, by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step2 : Create a transport, with options to launch/connect to a MCP Server @@ -75,7 +93,13 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await?; // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml index 9ed9816..c669c40 100644 --- a/examples/simple-mcp-client-streamable-http-core/Cargo.toml +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", "streamable-http", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/simple-mcp-client-streamable-http-core/src/handler.rs b/examples/simple-mcp-client-streamable-http-core/src/handler.rs index ab86e9e..b16a1fd 100644 --- a/examples/simple-mcp-client-streamable-http-core/src/handler.rs +++ b/examples/simple-mcp-client-streamable-http-core/src/handler.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use rust_mcp_sdk::schema::{ self, schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - RpcError, ServerRequest, + RpcError, }; use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; pub struct MyClientHandler; @@ -18,23 +18,24 @@ impl ClientHandlerCore for MyClientHandler { _runtime: &dyn McpClient, ) -> std::result::Result { match request { - RequestFromServer::ServerRequest(server_request) => match server_request { - ServerRequest::PingRequest(_) => { - return Ok(schema::Result::default().into()); - } - ServerRequest::CreateMessageRequest(_create_message_request) => { - Err(RpcError::internal_error().with_message( - "CreateMessageRequest handler is not implemented".to_string(), - )) - } - ServerRequest::ListRootsRequest(_list_roots_request) => { - Err(RpcError::internal_error() - .with_message("ListRootsRequest handler is not implemented".to_string())) - } - ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() - .with_message("ElicitRequest handler is not implemented".to_string())), - }, - RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + RequestFromServer::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + RequestFromServer::CreateMessageRequest(_) => Err(RpcError::internal_error() + .with_message("CreateMessageRequest handler is not implemented".to_string())), + RequestFromServer::ListRootsRequest(_) => Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())), + RequestFromServer::ElicitRequest(_) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskRequest(_) => Err(RpcError::internal_error() + .with_message("GetTaskRequest handler is not implemented".to_string())), + RequestFromServer::GetTaskPayloadRequest(_) => Err(RpcError::internal_error() + .with_message("GetTaskPayloadRequest handler is not implemented".to_string())), + RequestFromServer::CancelTaskRequest(_) => Err(RpcError::internal_error() + .with_message("CancelTaskRequest handler is not implemented".to_string())), + RequestFromServer::ListTasksRequest(_) => Err(RpcError::internal_error() + .with_message("ListTasksRequest handler is not implemented".to_string())), + RequestFromServer::CustomRequest(_) => Err(RpcError::internal_error() .with_message("CustomRequest handler is not implemented".to_string())), } } @@ -44,20 +45,7 @@ impl ClientHandlerCore for MyClientHandler { notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - if let NotificationFromServer::ServerNotification( - schema::ServerNotification::LoggingMessageNotification(logging_message_notification), - ) = notification - { - println!( - "Notification from server: {}", - logging_message_notification.params.data - ); - } else { - println!( - "A {} notification received from the server", - notification.method() - ); - }; + println!("Notification from server: \"{}\"", notification.method()); Ok(()) } diff --git a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-streamable-http-core/src/main.rs b/examples/simple-mcp-client-streamable-http-core/src/main.rs index e1a5849..c875710 100644 --- a/examples/simple-mcp-client-streamable-http-core/src/main.rs +++ b/examples/simple-mcp-client-streamable-http-core/src/main.rs @@ -8,9 +8,9 @@ use rust_mcp_sdk::error::SdkResult; use rust_mcp_sdk::mcp_client::client_runtime_core; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; -use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use rust_mcp_sdk::{mcp_icon, McpClient, RequestOptions, StreamableTransportOptions}; use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -31,16 +31,25 @@ async fn main() -> SdkResult<()> { let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client-core-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + name: "simple-rust-mcp-client-core-sse".into(), + version: "0.1.0".into(), + title: Some("Simple Rust MCP Client (Core,SSE)".into()), + description: Some("Simple Rust MCP Client (Core,SSE), by Rust MCP SDK".into()), + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: Some("https://github.com/rust-mcp-stack/rust-mcp-sdk".into()), }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. let transport_options = StreamableTransportOptions { - mcp_url: MCP_SERVER_URL.to_string(), + mcp_url: MCP_SERVER_URL.into(), request_options: RequestOptions { ..RequestOptions::default() }, @@ -85,7 +94,13 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await?; // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml index 42aa6a6..a83bf5c 100644 --- a/examples/simple-mcp-client-streamable-http/Cargo.toml +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -11,7 +11,6 @@ rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "streamable-http", "macros", - "2025_06_18", ] } tokio = { workspace = true } diff --git a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs index a8e7c9c..e423e05 100644 --- a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs @@ -69,7 +69,7 @@ impl InquiryUtils { return Ok(()); } - let tools = self.client.list_tools(None).await?; + let tools = self.client.request_tool_list(None).await?; self.print_header("Tools"); self.print_list( tools @@ -93,7 +93,7 @@ impl InquiryUtils { return Ok(()); } - let prompts = self.client.list_prompts(None).await?; + let prompts = self.client.request_prompt_list(None).await?; self.print_header("Prompts"); self.print_list( @@ -117,7 +117,7 @@ impl InquiryUtils { return Ok(()); } - let resources = self.client.list_resources(None).await?; + let resources = self.client.request_resource_list(None).await?; self.print_header("Resources"); @@ -147,7 +147,7 @@ impl InquiryUtils { return Ok(()); } - let templates = self.client.list_resource_templates(None).await?; + let templates = self.client.request_resource_template_list(None).await?; self.print_header("Resource Templates"); @@ -185,9 +185,11 @@ impl InquiryUtils { // invoke the tool let result = self .client - .call_tool(CallToolRequestParams { + .request_tool_call(CallToolRequestParams { name: "add".to_string(), arguments: Some(params), + meta: None, + task: None, }) .await?; @@ -204,7 +206,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({ping_index} out of {max_pings})..."); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping(None).await; + let ping_result = self.client.ping(None, None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs index 95d4d8d..1ad61a9 100644 --- a/examples/simple-mcp-client-streamable-http/src/main.rs +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -7,9 +7,9 @@ use rust_mcp_sdk::error::SdkResult; use rust_mcp_sdk::mcp_client::client_runtime; use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, - LATEST_PROTOCOL_VERSION, + SetLevelRequestParams, LATEST_PROTOCOL_VERSION, }; -use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use rust_mcp_sdk::{mcp_icon, McpClient, RequestOptions, StreamableTransportOptions}; use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -31,16 +31,25 @@ async fn main() -> SdkResult<()> { let client_details: InitializeRequestParams = InitializeRequestParams { capabilities: ClientCapabilities::default(), client_info: Implementation { - name: "simple-rust-mcp-client-sse".to_string(), - version: "0.1.0".to_string(), - title: Some("Simple Rust MCP Client (SSE)".to_string()), + name: "simple-rust-mcp-client-sse".into(), + version: "0.1.0".into(), + title: Some("Simple Rust MCP Client (SSE)".into()), + description: None, + icons: vec![mcp_icon!( + src = "https://raw.githubusercontent.com/rust-mcp-stack/rust-mcp-sdk/main/assets/rust-mcp-icon.png", + mime_type = "image/png", + sizes = ["128x128"], + theme = "dark" + )], + website_url: None, }, protocol_version: LATEST_PROTOCOL_VERSION.into(), + meta: None, }; // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. let transport_options = StreamableTransportOptions { - mcp_url: MCP_SERVER_URL.to_string(), + mcp_url: MCP_SERVER_URL.into(), request_options: RequestOptions { ..RequestOptions::default() }, @@ -86,7 +95,14 @@ async fn main() -> SdkResult<()> { utils.call_add_tool(100, 25).await?; // Set the log level - match utils.client.set_logging_level(LoggingLevel::Debug).await { + match utils + .client + .request_set_logging_level(SetLevelRequestParams { + level: LoggingLevel::Debug, + meta: None, + }) + .await + { Ok(_) => println!("Log level is set to \"Debug\""), Err(err) => eprintln!("Error setting the Log level : {err}"), }