From 98fe88924dba15f1192f3479732f47a34cb8d7be Mon Sep 17 00:00:00 2001 From: John Harrison Date: Tue, 3 Dec 2024 11:11:32 -0800 Subject: [PATCH 01/15] [lldb] For a host socket, add a method to print the listening address. (#118330) This is most useful if you are listening on an address like 'localhost:0' and want to know the resolved ip + port of the socket listener. (cherry picked from commit 384562495bae44be053c1bbd40c359ef4b82d803) --- lldb/include/lldb/Host/Socket.h | 6 ++++++ lldb/include/lldb/Host/common/TCPSocket.h | 4 ++++ lldb/include/lldb/Host/posix/DomainSocket.h | 4 ++++ lldb/source/Host/common/TCPSocket.cpp | 8 ++++++++ lldb/source/Host/posix/DomainSocket.cpp | 14 ++++++++++++++ lldb/unittests/Host/SocketTest.cpp | 20 +++++++++++++++++++- 6 files changed, 55 insertions(+), 1 deletion(-) diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index 304a91bdf6741..04233154aad34 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -11,6 +11,7 @@ #include #include +#include #include "lldb/lldb-private.h" @@ -132,6 +133,11 @@ class Socket : public IOObject { // If this Socket is connected then return the URI used to connect. virtual std::string GetRemoteConnectionURI() const { return ""; }; + // If the Socket is listening then return the URI for clients to connect. + virtual std::vector GetListeningConnectionURI() const { + return {}; + } + protected: Socket(SocketProtocol protocol, bool should_close, bool m_child_process_inherit); diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h index 78e80568e3996..f38fd50e5b2b3 100644 --- a/lldb/include/lldb/Host/common/TCPSocket.h +++ b/lldb/include/lldb/Host/common/TCPSocket.h @@ -13,6 +13,8 @@ #include "lldb/Host/Socket.h" #include "lldb/Host/SocketAddress.h" #include +#include +#include namespace lldb_private { class TCPSocket : public Socket { @@ -59,6 +61,8 @@ class TCPSocket : public Socket { std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + private: TCPSocket(NativeSocket socket, const TCPSocket &listen_socket); diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h b/lldb/include/lldb/Host/posix/DomainSocket.h index 35c33811f60de..d625b9a2aeb82 100644 --- a/lldb/include/lldb/Host/posix/DomainSocket.h +++ b/lldb/include/lldb/Host/posix/DomainSocket.h @@ -10,6 +10,8 @@ #define LLDB_HOST_POSIX_DOMAINSOCKET_H #include "lldb/Host/Socket.h" +#include +#include namespace lldb_private { class DomainSocket : public Socket { @@ -22,6 +24,8 @@ class DomainSocket : public Socket { std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + protected: DomainSocket(SocketProtocol protocol, bool child_processes_inherit); diff --git a/lldb/source/Host/common/TCPSocket.cpp b/lldb/source/Host/common/TCPSocket.cpp index 1f31190b02f97..3c293ee428837 100644 --- a/lldb/source/Host/common/TCPSocket.cpp +++ b/lldb/source/Host/common/TCPSocket.cpp @@ -137,6 +137,14 @@ std::string TCPSocket::GetRemoteConnectionURI() const { return ""; } +std::vector TCPSocket::GetListeningConnectionURI() const { + std::vector URIs; + for (const auto &[fd, addr] : m_listen_sockets) + URIs.emplace_back(llvm::formatv("connection://[{0}]:{1}", + addr.GetIPAddress(), addr.GetPort())); + return URIs; +} + Status TCPSocket::CreateSocket(int domain) { Status error; if (IsValid()) diff --git a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp index 2d18995c3bb46..1804c65de0ab5 100644 --- a/lldb/source/Host/posix/DomainSocket.cpp +++ b/lldb/source/Host/posix/DomainSocket.cpp @@ -155,3 +155,17 @@ std::string DomainSocket::GetRemoteConnectionURI() const { "{0}://{1}", GetNameOffset() == 0 ? "unix-connect" : "unix-abstract-connect", name); } + +std::vector DomainSocket::GetListeningConnectionURI() const { + if (m_socket == kInvalidSocketValue) + return {}; + + struct sockaddr_un addr; + bzero(&addr, sizeof(struct sockaddr_un)); + addr.sun_family = AF_UNIX; + socklen_t addr_len = sizeof(struct sockaddr_un); + if (::getsockname(m_socket, (struct sockaddr *)&addr, &addr_len) != 0) + return {}; + + return {llvm::formatv("unix-connect://{0}", addr.sun_path)}; +} diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 3a356d11ba1a5..663a108f3037f 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -168,12 +168,30 @@ TEST_P(SocketTest, TCPListen0GetPort) { if (!HostSupportsIPv4()) return; llvm::Expected> sock = - Socket::TcpListen("10.10.12.3:0", false); + Socket::TcpListen("10.10.12.3:0", 5); ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); ASSERT_TRUE(sock.get()->IsValid()); EXPECT_NE(sock.get()->GetLocalPortNumber(), 0); } +TEST_P(SocketTest, TCPListen0GetListeningConnectionURI) { + if (!HostSupportsProtocol()) + return; + + std::string addr = llvm::formatv("[{0}]:0", GetParam().localhost_ip).str(); + llvm::Expected> sock = + Socket::TcpListen(addr, false); + ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); + ASSERT_TRUE(sock.get()->IsValid()); + + EXPECT_THAT( + sock.get()->GetListeningConnectionURI(), + testing::ElementsAre(llvm::formatv("connection://[{0}]:{1}", + GetParam().localhost_ip, + sock->get()->GetLocalPortNumber()) + .str())); +} + TEST_P(SocketTest, TCPGetConnectURI) { std::unique_ptr socket_a_up; std::unique_ptr socket_b_up; From df5390103765d15e33359daf9530f67bcf3900ff Mon Sep 17 00:00:00 2001 From: Pavel Labath Date: Fri, 13 Sep 2024 12:56:52 +0200 Subject: [PATCH 02/15] [lldb] Add a MainLoop version of DomainSocket::Accept (#108188) To go along with the existing TCPSocket implementation. (cherry picked from commit ebbc9ed2d60cacffc87232dc32374a2b38b92175) --- lldb/include/lldb/Host/Socket.h | 13 ++++++- lldb/include/lldb/Host/common/TCPSocket.h | 10 +---- lldb/include/lldb/Host/common/UDPSocket.h | 8 +++- lldb/include/lldb/Host/posix/DomainSocket.h | 8 +++- lldb/source/Host/common/Socket.cpp | 14 +++++++ lldb/source/Host/common/TCPSocket.cpp | 19 ++-------- lldb/source/Host/common/UDPSocket.cpp | 4 -- lldb/source/Host/posix/DomainSocket.cpp | 42 +++++++++++++++++---- lldb/unittests/Host/SocketTest.cpp | 41 +++++++++++++++++++- 9 files changed, 118 insertions(+), 41 deletions(-) diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index 04233154aad34..c982b6d9af9a1 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -13,6 +13,7 @@ #include #include +#include "lldb/Host/MainLoopBase.h" #include "lldb/lldb-private.h" #include "lldb/Host/SocketAddress.h" @@ -98,7 +99,17 @@ class Socket : public IOObject { virtual Status Connect(llvm::StringRef name) = 0; virtual Status Listen(llvm::StringRef name, int backlog) = 0; - virtual Status Accept(Socket *&socket) = 0; + + // Use the provided main loop instance to accept new connections. The callback + // will be called (from MainLoop::Run) for each new connection. This function + // does not block. + virtual llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) = 0; + + // Accept a single connection and "return" it in the pointer argument. This + // function blocks until the connection arrives. + virtual Status Accept(Socket *&socket); // Initialize a Tcp Socket object in listening mode. listen and accept are // implemented separately because the caller may wish to manipulate or query diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h index f38fd50e5b2b3..a37ae843bed23 100644 --- a/lldb/include/lldb/Host/common/TCPSocket.h +++ b/lldb/include/lldb/Host/common/TCPSocket.h @@ -44,16 +44,10 @@ class TCPSocket : public Socket { Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - // Use the provided main loop instance to accept new connections. The callback - // will be called (from MainLoop::Run) for each new connection. This function - // does not block. + using Socket::Accept; llvm::Expected> Accept(MainLoopBase &loop, - std::function socket)> sock_cb); - - // Accept a single connection and "return" it in the pointer argument. This - // function blocks until the connection arrives. - Status Accept(Socket *&conn_socket) override; + std::function socket)> sock_cb) override; Status CreateSocket(int domain); diff --git a/lldb/include/lldb/Host/common/UDPSocket.h b/lldb/include/lldb/Host/common/UDPSocket.h index bae707e345d87..7348010d02ada 100644 --- a/lldb/include/lldb/Host/common/UDPSocket.h +++ b/lldb/include/lldb/Host/common/UDPSocket.h @@ -27,7 +27,13 @@ class UDPSocket : public Socket { size_t Send(const void *buf, const size_t num_bytes) override; Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) override { + return llvm::errorCodeToError( + std::make_error_code(std::errc::operation_not_supported)); + } SocketAddress m_sockaddr; }; diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h b/lldb/include/lldb/Host/posix/DomainSocket.h index d625b9a2aeb82..3a7fb16d3fd75 100644 --- a/lldb/include/lldb/Host/posix/DomainSocket.h +++ b/lldb/include/lldb/Host/posix/DomainSocket.h @@ -16,11 +16,17 @@ namespace lldb_private { class DomainSocket : public Socket { public: + DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit); DomainSocket(bool should_close, bool child_processes_inherit); Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + using Socket::Accept; + llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) override; std::string GetRemoteConnectionURI() const override; diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index 1a506aa95b246..2ed6e30cc1566 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -10,6 +10,7 @@ #include "lldb/Host/Config.h" #include "lldb/Host/Host.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Host/SocketAddress.h" #include "lldb/Host/common/TCPSocket.h" #include "lldb/Host/common/UDPSocket.h" @@ -443,6 +444,19 @@ NativeSocket Socket::CreateSocket(const int domain, const int type, return sock; } +Status Socket::Accept(Socket *&socket) { + MainLoop accept_loop; + llvm::Expected> expected_handles = + Accept(accept_loop, + [&accept_loop, &socket](std::unique_ptr sock) { + socket = sock.release(); + accept_loop.RequestTermination(); + }); + if (!expected_handles) + return Status::FromError(expected_handles.takeError()); + return accept_loop.Run(); +} + NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr, socklen_t *addrlen, bool child_processes_inherit, Status &error) { diff --git a/lldb/source/Host/common/TCPSocket.cpp b/lldb/source/Host/common/TCPSocket.cpp index 3c293ee428837..fef71810df202 100644 --- a/lldb/source/Host/common/TCPSocket.cpp +++ b/lldb/source/Host/common/TCPSocket.cpp @@ -263,9 +263,9 @@ void TCPSocket::CloseListenSockets() { m_listen_sockets.clear(); } -llvm::Expected> TCPSocket::Accept( - MainLoopBase &loop, - std::function socket)> sock_cb) { +llvm::Expected> +TCPSocket::Accept(MainLoopBase &loop, + std::function socket)> sock_cb) { if (m_listen_sockets.size() == 0) return llvm::createStringError("No open listening sockets!"); @@ -309,19 +309,6 @@ llvm::Expected> TCPSocket::Accept( return handles; } -Status TCPSocket::Accept(Socket *&conn_socket) { - MainLoop accept_loop; - llvm::Expected> expected_handles = - Accept(accept_loop, - [&accept_loop, &conn_socket](std::unique_ptr sock) { - conn_socket = sock.release(); - accept_loop.RequestTermination(); - }); - if (!expected_handles) - return Status::FromError(expected_handles.takeError()); - return accept_loop.Run(); -} - int TCPSocket::SetOptionNoDelay() { return SetOption(IPPROTO_TCP, TCP_NODELAY, 1); } diff --git a/lldb/source/Host/common/UDPSocket.cpp b/lldb/source/Host/common/UDPSocket.cpp index 2a7a6cff414b1..05d7b2e650602 100644 --- a/lldb/source/Host/common/UDPSocket.cpp +++ b/lldb/source/Host/common/UDPSocket.cpp @@ -47,10 +47,6 @@ Status UDPSocket::Listen(llvm::StringRef name, int backlog) { return Status::FromErrorStringWithFormat("%s", g_not_supported_error); } -Status UDPSocket::Accept(Socket *&socket) { - return Status::FromErrorStringWithFormat("%s", g_not_supported_error); -} - llvm::Expected> UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit) { std::unique_ptr socket; diff --git a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp index 1804c65de0ab5..6822932274b31 100644 --- a/lldb/source/Host/posix/DomainSocket.cpp +++ b/lldb/source/Host/posix/DomainSocket.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/posix/DomainSocket.h" +#include "lldb/Utility/LLDBLog.h" #include "llvm/Support/Errno.h" #include "llvm/Support/FileSystem.h" #include +#include #include #include @@ -57,7 +59,14 @@ static bool SetSockAddr(llvm::StringRef name, const size_t name_offset, } DomainSocket::DomainSocket(bool should_close, bool child_processes_inherit) - : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) {} + : DomainSocket(kInvalidSocketValue, should_close, child_processes_inherit) { +} + +DomainSocket::DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit) + : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) { + m_socket = socket; +} DomainSocket::DomainSocket(SocketProtocol protocol, bool child_processes_inherit) @@ -108,14 +117,31 @@ Status DomainSocket::Listen(llvm::StringRef name, int backlog) { return error; } -Status DomainSocket::Accept(Socket *&socket) { - Status error; - auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, - m_child_processes_inherit, error); - if (error.Success()) - socket = new DomainSocket(conn_fd, *this); +llvm::Expected> DomainSocket::Accept( + MainLoopBase &loop, + std::function socket)> sock_cb) { + // TODO: Refactor MainLoop to avoid the shared_ptr requirement. + auto io_sp = std::make_shared(GetNativeSocket(), false, + m_child_processes_inherit); + auto cb = [this, sock_cb](MainLoopBase &loop) { + Log *log = GetLog(LLDBLog::Host); + Status error; + auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, + m_child_processes_inherit, error); + if (error.Fail()) { + LLDB_LOG(log, "AcceptSocket({0}): {1}", GetNativeSocket(), error); + return; + } + std::unique_ptr sock_up(new DomainSocket(conn_fd, *this)); + sock_cb(std::move(sock_up)); + }; - return error; + Status error; + std::vector handles; + handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error)); + if (error.Fail()) + return error.ToError(); + return handles; } size_t DomainSocket::GetNameOffset() const { return 0; } diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 663a108f3037f..6b5efe5110a75 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -85,6 +85,43 @@ TEST_P(SocketTest, DomainListenConnectAccept) { std::unique_ptr socket_b_up; CreateDomainConnectedSockets(Path, &socket_a_up, &socket_b_up); } + +TEST_P(SocketTest, DomainMainLoopAccept) { + llvm::SmallString<64> Path; + std::error_code EC = + llvm::sys::fs::createUniqueDirectory("DomainListenConnectAccept", Path); + ASSERT_FALSE(EC); + llvm::sys::path::append(Path, "test"); + + // Skip the test if the $TMPDIR is too long to hold a domain socket. + if (Path.size() > 107u) + return; + + auto listen_socket_up = std::make_unique( + /*should_close=*/true, /*child_process_inherit=*/false); + Status error = listen_socket_up->Listen(Path, 5); + ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); + ASSERT_TRUE(listen_socket_up->IsValid()); + + MainLoop loop; + std::unique_ptr accepted_socket_up; + auto expected_handles = listen_socket_up->Accept( + loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { + accepted_socket_up = std::move(sock_up); + loop.RequestTermination(); + }); + ASSERT_THAT_EXPECTED(expected_handles, llvm::Succeeded()); + + auto connect_socket_up = std::make_unique( + /*should_close=*/true, /*child_process_inherit=*/false); + ASSERT_THAT_ERROR(connect_socket_up->Connect(Path).ToError(), + llvm::Succeeded()); + ASSERT_TRUE(connect_socket_up->IsValid()); + + loop.Run(); + ASSERT_TRUE(accepted_socket_up); + ASSERT_TRUE(accepted_socket_up->IsValid()); +} #endif TEST_P(SocketTest, TCPListen0ConnectAccept) { @@ -109,9 +146,9 @@ TEST_P(SocketTest, TCPMainLoopAccept) { ASSERT_TRUE(listen_socket_up->IsValid()); MainLoop loop; - std::unique_ptr accepted_socket_up; + std::unique_ptr accepted_socket_up; auto expected_handles = listen_socket_up->Accept( - loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { + loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { accepted_socket_up = std::move(sock_up); loop.RequestTermination(); }); From ac1044e19c4a66f27a046247444b25a5a0b85c17 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Wed, 11 Jun 2025 19:51:05 -0700 Subject: [PATCH 03/15] [lldb] Move Transport class into lldb_private (NFC) (#143806) Move lldb-dap's Transport class into lldb_private so the code can be shared between the "JSON with header" protocol used by DAP and the JSON RPC protocol used by MCP (see [1]). [1]: https://discourse.llvm.org/t/rfc-adding-mcp-support-to-lldb/86798 (cherry picked from commit de51b2dd3c6fc995e7db56fc50b4c8dceddc0aab) --- lldb/include/lldb/Host/JSONTransport.h | 126 +++++++++++++++++++ lldb/source/Host/CMakeLists.txt | 3 +- lldb/source/Host/common/JSONTransport.cpp | 147 ++++++++++++++++++++++ 3 files changed, 275 insertions(+), 1 deletion(-) create mode 100644 lldb/include/lldb/Host/JSONTransport.h create mode 100644 lldb/source/Host/common/JSONTransport.cpp diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h new file mode 100644 index 0000000000000..4db5e417ea852 --- /dev/null +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -0,0 +1,126 @@ +//===-- JSONTransport.h ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transport layer for encoding and decoding JSON protocol messages. +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_HOST_JSONTRANSPORT_H +#define LLDB_HOST_JSONTRANSPORT_H + +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" +#include +#include + +namespace lldb_private { + +class TransportEOFError : public llvm::ErrorInfo { +public: + static char ID; + + TransportEOFError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport end of file reached"; + } + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } +}; + +class TransportTimeoutError : public llvm::ErrorInfo { +public: + static char ID; + + TransportTimeoutError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport operation timed out"; + } + std::error_code convertToErrorCode() const override { + return std::make_error_code(std::errc::timed_out); + } +}; + +class TransportClosedError : public llvm::ErrorInfo { +public: + static char ID; + + TransportClosedError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport is closed"; + } + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } +}; + +/// A transport class that uses JSON for communication. +class JSONTransport { +public: + JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); + virtual ~JSONTransport() = default; + + /// Transport is not copyable. + /// @{ + JSONTransport(const JSONTransport &rhs) = delete; + void operator=(const JSONTransport &rhs) = delete; + /// @} + + /// Writes a message to the output stream. + template llvm::Error Write(const T &t) { + const std::string message = llvm::formatv("{0}", toJSON(t)).str(); + return WriteImpl(message); + } + + /// Reads the next message from the input stream. + template + llvm::Expected Read(const std::chrono::microseconds &timeout) { + llvm::Expected message = ReadImpl(timeout); + if (!message) + return message.takeError(); + return llvm::json::parse(/*JSON=*/*message); + } + +protected: + virtual void Log(llvm::StringRef message); + + virtual llvm::Error WriteImpl(const std::string &message) = 0; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) = 0; + + lldb::IOObjectSP m_input; + lldb::IOObjectSP m_output; +}; + +/// A transport class for JSON with a HTTP header. +class HTTPDelimitedJSONTransport : public JSONTransport { +public: + HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) + : JSONTransport(input, output) {} + virtual ~HTTPDelimitedJSONTransport() = default; + +protected: + virtual llvm::Error WriteImpl(const std::string &message) override; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) override; + + // FIXME: Support any header. + static constexpr llvm::StringLiteral kHeaderContentLength = + "Content-Length: "; + static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n\r\n"; +}; + +} // namespace lldb_private + +#endif diff --git a/lldb/source/Host/CMakeLists.txt b/lldb/source/Host/CMakeLists.txt index 8b96bb1451fce..e60e0860a90ca 100644 --- a/lldb/source/Host/CMakeLists.txt +++ b/lldb/source/Host/CMakeLists.txt @@ -24,8 +24,9 @@ add_host_subdirectory(common common/HostNativeThreadBase.cpp common/HostProcess.cpp common/HostThread.cpp - common/LockFileBase.cpp + common/JSONTransport.cpp common/LZMA.cpp + common/LockFileBase.cpp common/MainLoopBase.cpp common/MonitoringProcessLauncher.cpp common/NativeProcessProtocol.cpp diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp new file mode 100644 index 0000000000000..103c76d25daf7 --- /dev/null +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -0,0 +1,147 @@ +//===-- JSONTransport.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Host/JSONTransport.h" +#include "lldb/Utility/IOObject.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "lldb/Utility/SelectHelper.h" +#include "lldb/Utility/Status.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +/// ReadFull attempts to read the specified number of bytes. If EOF is +/// encountered, an empty string is returned. +static Expected +ReadFull(IOObject &descriptor, size_t length, + std::optional timeout = std::nullopt) { + if (!descriptor.IsValid()) + return llvm::make_error(); + + bool timeout_supported = true; + // FIXME: SelectHelper does not work with NativeFile on Win32. +#if _WIN32 + timeout_supported = descriptor.GetFdType() == IOObject::eFDTypeSocket; +#endif + + if (timeout && timeout_supported) { + SelectHelper sh; + sh.SetTimeout(*timeout); + sh.FDSetRead(descriptor.GetWaitableHandle()); + Status status = sh.Select(); + if (status.Fail()) { + // Convert timeouts into a specific error. + if (status.GetType() == lldb::eErrorTypePOSIX && + status.GetError() == ETIMEDOUT) + return make_error(); + return status.takeError(); + } + } + + std::string data; + data.resize(length); + Status status = descriptor.Read(data.data(), length); + if (status.Fail()) + return status.takeError(); + + // Read returns '' on EOF. + if (length == 0) + return make_error(); + + // Return the actual number of bytes read. + return data.substr(0, length); +} + +static Expected +ReadUntil(IOObject &descriptor, StringRef delimiter, + std::optional timeout = std::nullopt) { + std::string buffer; + buffer.reserve(delimiter.size() + 1); + while (!llvm::StringRef(buffer).ends_with(delimiter)) { + Expected next = + ReadFull(descriptor, buffer.empty() ? delimiter.size() : 1, timeout); + if (auto Err = next.takeError()) + return std::move(Err); + buffer += *next; + } + return buffer.substr(0, buffer.size() - delimiter.size()); +} + +JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) + : m_input(std::move(input)), m_output(std::move(output)) {} + +void JSONTransport::Log(llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); +} + +Expected +HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { + if (!m_input || !m_input->IsValid()) + return createStringError("transport output is closed"); + + IOObject *input = m_input.get(); + Expected message_header = + ReadFull(*input, kHeaderContentLength.size(), timeout); + if (!message_header) + return message_header.takeError(); + if (*message_header != kHeaderContentLength) + return createStringError(formatv("expected '{0}' and got '{1}'", + kHeaderContentLength, *message_header) + .str()); + + Expected raw_length = ReadUntil(*input, kHeaderSeparator); + if (!raw_length) + return handleErrors(raw_length.takeError(), + [&](const TransportEOFError &E) -> llvm::Error { + return createStringError( + "unexpected EOF while reading header separator"); + }); + + size_t length; + if (!to_integer(*raw_length, length)) + return createStringError( + formatv("invalid content length {0}", *raw_length).str()); + + Expected raw_json = ReadFull(*input, length); + if (!raw_json) + return handleErrors( + raw_json.takeError(), [&](const TransportEOFError &E) -> llvm::Error { + return createStringError("unexpected EOF while reading JSON"); + }); + + Log(llvm::formatv("--> {0}", *raw_json).str()); + + return raw_json; +} + +Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { + if (!m_output || !m_output->IsValid()) + return llvm::make_error(); + + Log(llvm::formatv("<-- {0}", message).str()); + + std::string Output; + raw_string_ostream OS(Output); + OS << kHeaderContentLength << message.length() << kHeaderSeparator << message; + size_t num_bytes = Output.size(); + return m_output->Write(Output.data(), num_bytes).takeError(); +} + +char TransportEOFError::ID; +char TransportTimeoutError::ID; +char TransportClosedError::ID; From d311d186a0b183f38e025c95b096e4f158297d6d Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Thu, 12 Jun 2025 14:52:43 -0700 Subject: [PATCH 04/15] [lldb] Implement JSON RPC (newline delimited) Transport (#143946) This PR implements JSON RPC-style (i.e. newline delimited) JSON transport. I moved the existing transport tests from DAP to Host and moved the PipeTest base class into TestingSupport so it can be shared by both. (cherry picked from commit 8a2895ad89793591cd3f0114bc56cd345f651823) --- lldb/include/lldb/Host/JSONTransport.h | 23 ++- lldb/source/Host/common/JSONTransport.cpp | 37 +++- lldb/unittests/DAP/TestBase.cpp | 129 +++++++++++++ lldb/unittests/Host/CMakeLists.txt | 1 + lldb/unittests/Host/JSONTransportTest.cpp | 176 ++++++++++++++++++ .../TestingSupport/Host/PipeTestUtilities.h | 28 +++ 6 files changed, 386 insertions(+), 8 deletions(-) create mode 100644 lldb/unittests/DAP/TestBase.cpp create mode 100644 lldb/unittests/Host/JSONTransportTest.cpp create mode 100644 lldb/unittests/TestingSupport/Host/PipeTestUtilities.h diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 4db5e417ea852..4087cdf2b42f7 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -51,17 +51,17 @@ class TransportTimeoutError : public llvm::ErrorInfo { } }; -class TransportClosedError : public llvm::ErrorInfo { +class TransportInvalidError : public llvm::ErrorInfo { public: static char ID; - TransportClosedError() = default; + TransportInvalidError() = default; void log(llvm::raw_ostream &OS) const override { - OS << "transport is closed"; + OS << "transport IO object invalid"; } std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); + return std::make_error_code(std::errc::not_connected); } }; @@ -121,6 +121,21 @@ class HTTPDelimitedJSONTransport : public JSONTransport { static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n\r\n"; }; +/// A transport class for JSON RPC. +class JSONRPCTransport : public JSONTransport { +public: + JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) + : JSONTransport(input, output) {} + virtual ~JSONRPCTransport() = default; + +protected: + virtual llvm::Error WriteImpl(const std::string &message) override; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) override; + + static constexpr llvm::StringLiteral kMessageSeparator = "\n"; +}; + } // namespace lldb_private #endif diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 103c76d25daf7..1a0851d5c4365 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -31,7 +31,7 @@ static Expected ReadFull(IOObject &descriptor, size_t length, std::optional timeout = std::nullopt) { if (!descriptor.IsValid()) - return llvm::make_error(); + return llvm::make_error(); bool timeout_supported = true; // FIXME: SelectHelper does not work with NativeFile on Win32. @@ -92,7 +92,7 @@ void JSONTransport::Log(llvm::StringRef message) { Expected HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { if (!m_input || !m_input->IsValid()) - return createStringError("transport output is closed"); + return llvm::make_error(); IOObject *input = m_input.get(); Expected message_header = @@ -131,7 +131,7 @@ HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { if (!m_output || !m_output->IsValid()) - return llvm::make_error(); + return llvm::make_error(); Log(llvm::formatv("<-- {0}", message).str()); @@ -142,6 +142,35 @@ Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { return m_output->Write(Output.data(), num_bytes).takeError(); } +Expected +JSONRPCTransport::ReadImpl(const std::chrono::microseconds &timeout) { + if (!m_input || !m_input->IsValid()) + return make_error(); + + IOObject *input = m_input.get(); + Expected raw_json = + ReadUntil(*input, kMessageSeparator, timeout); + if (!raw_json) + return raw_json.takeError(); + + Log(llvm::formatv("--> {0}", *raw_json).str()); + + return *raw_json; +} + +Error JSONRPCTransport::WriteImpl(const std::string &message) { + if (!m_output || !m_output->IsValid()) + return llvm::make_error(); + + Log(llvm::formatv("<-- {0}", message).str()); + + std::string Output; + llvm::raw_string_ostream OS(Output); + OS << message << kMessageSeparator; + size_t num_bytes = Output.size(); + return m_output->Write(Output.data(), num_bytes).takeError(); +} + char TransportEOFError::ID; char TransportTimeoutError::ID; -char TransportClosedError::ID; +char TransportInvalidError::ID; diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp new file mode 100644 index 0000000000000..27ad42686fbbf --- /dev/null +++ b/lldb/unittests/DAP/TestBase.cpp @@ -0,0 +1,129 @@ +//===-- TestBase.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestBase.h" +#include "Protocol/ProtocolBase.h" +#include "TestingSupport/TestUtilities.h" +#include "lldb/API/SBDefines.h" +#include "lldb/API/SBStructuredData.h" +#include "lldb/Host/File.h" +#include "lldb/Host/Pipe.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_dap; +using namespace lldb_dap::protocol; +using namespace lldb_dap_tests; +using lldb_private::File; +using lldb_private::NativeFile; +using lldb_private::Pipe; + +void TransportBase::SetUp() { + PipeTest::SetUp(); + to_dap = std::make_unique( + "to_dap", nullptr, + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + from_dap = std::make_unique( + "from_dap", nullptr, + std::make_shared(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); +} + +void DAPTestBase::SetUp() { + TransportBase::SetUp(); + dap = std::make_unique( + /*log=*/nullptr, + /*default_repl_mode=*/ReplMode::Auto, + /*pre_init_commands=*/std::vector(), + /*transport=*/*to_dap); +} + +void DAPTestBase::TearDown() { + if (core) + ASSERT_THAT_ERROR(core->discard(), Succeeded()); + if (binary) + ASSERT_THAT_ERROR(binary->discard(), Succeeded()); +} + +void DAPTestBase::SetUpTestSuite() { + lldb::SBError error = SBDebugger::InitializeWithErrorHandling(); + EXPECT_TRUE(error.Success()); +} +void DAPTestBase::TeatUpTestSuite() { SBDebugger::Terminate(); } + +bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { + EXPECT_TRUE(dap->debugger); + + lldb::SBStructuredData data = dap->debugger.GetBuildConfiguration() + .GetValueForKey("targets") + .GetValueForKey("value"); + for (size_t i = 0; i < data.GetSize(); i++) { + char buf[100] = {0}; + size_t size = data.GetItemAtIndex(i).GetStringValue(buf, sizeof(buf)); + if (llvm::StringRef(buf, size) == platform) + return true; + } + + return false; +} + +void DAPTestBase::CreateDebugger() { + dap->debugger = lldb::SBDebugger::Create(); + ASSERT_TRUE(dap->debugger); +} + +void DAPTestBase::LoadCore() { + ASSERT_TRUE(dap->debugger); + llvm::Expected binary_yaml = + lldb_private::TestFile::fromYamlFile(k_linux_binary); + ASSERT_THAT_EXPECTED(binary_yaml, Succeeded()); + llvm::Expected binary_file = + binary_yaml->writeToTemporaryFile(); + ASSERT_THAT_EXPECTED(binary_file, Succeeded()); + binary = std::move(*binary_file); + dap->target = dap->debugger.CreateTarget(binary->TmpName.data()); + ASSERT_TRUE(dap->target); + llvm::Expected core_yaml = + lldb_private::TestFile::fromYamlFile(k_linux_core); + ASSERT_THAT_EXPECTED(core_yaml, Succeeded()); + llvm::Expected core_file = + core_yaml->writeToTemporaryFile(); + ASSERT_THAT_EXPECTED(core_file, Succeeded()); + this->core = std::move(*core_file); + SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); + ASSERT_TRUE(process); +} + +std::vector DAPTestBase::DrainOutput() { + std::vector msgs; + output.CloseWriteFileDescriptor(); + while (true) { + Expected next = + from_dap->Read(std::chrono::milliseconds(1)); + if (!next) { + consumeError(next.takeError()); + break; + } + msgs.push_back(*next); + } + return msgs; +} diff --git a/lldb/unittests/Host/CMakeLists.txt b/lldb/unittests/Host/CMakeLists.txt index e2cb0a9e5713a..7c7fabf9716e0 100644 --- a/lldb/unittests/Host/CMakeLists.txt +++ b/lldb/unittests/Host/CMakeLists.txt @@ -13,6 +13,7 @@ set (FILES HostInfoTest.cpp HostTest.cpp MainLoopTest.cpp + JSONTransportTest.cpp NativeProcessProtocolTest.cpp PipeTest.cpp ProcessLaunchInfoTest.cpp diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp new file mode 100644 index 0000000000000..f1ec5e03bbeca --- /dev/null +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -0,0 +1,176 @@ +//===-- JSONTransportTest.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Host/JSONTransport.h" +#include "TestingSupport/Host/PipeTestUtilities.h" +#include "lldb/Host/File.h" + +using namespace llvm; +using namespace lldb_private; + +namespace { +template class JSONTransportTest : public PipeTest { +protected: + std::unique_ptr transport; + + void SetUp() override { + PipeTest::SetUp(); + transport = std::make_unique( + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + } +}; + +class HTTPDelimitedJSONTransportTest + : public JSONTransportTest { +public: + using JSONTransportTest::JSONTransportTest; +}; + +class JSONRPCTransportTest : public JSONTransportTest { +public: + using JSONTransportTest::JSONTransportTest; +}; + +struct JSONTestType { + std::string str; +}; + +llvm::json::Value toJSON(const JSONTestType &T) { + return llvm::json::Object{{"str", T.str}}; +} + +bool fromJSON(const llvm::json::Value &V, JSONTestType &T, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("str", T.str); +} +} // namespace + +TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { + std::string malformed_header = "COnTent-LenGth: -1{}\r\n\r\nnotjosn"; + ASSERT_THAT_EXPECTED( + input.Write(malformed_header.data(), malformed_header.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + FailedWithMessage( + "expected 'Content-Length: ' and got 'COnTent-LenGth: '")); +} + +TEST_F(HTTPDelimitedJSONTransportTest, Read) { + std::string json = R"json({"str": "foo"})json"; + std::string message = + formatv("Content-Length: {0}\r\n\r\n{1}", json.size(), json).str(); + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadAfterClosed) { + input.CloseReadFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, Write) { + ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + output.CloseWriteFileDescriptor(); + char buf[1024]; + Expected bytes_read = + output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); + ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" + R"json({"str":"foo"})json")); +} + +TEST_F(JSONRPCTransportTest, MalformedRequests) { + std::string malformed_header = "notjson\n"; + ASSERT_THAT_EXPECTED( + input.Write(malformed_header.data(), malformed_header.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(JSONRPCTransportTest, Read) { + std::string json = R"json({"str": "foo"})json"; + std::string message = formatv("{0}\n", json).str(); + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(JSONRPCTransportTest, ReadWithEOF) { + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(JSONRPCTransportTest, ReadAfterClosed) { + input.CloseReadFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(JSONRPCTransportTest, Write) { + ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + output.CloseWriteFileDescriptor(); + char buf[1024]; + Expected bytes_read = + output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); + ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + "\n")); +} + +TEST_F(JSONRPCTransportTest, InvalidTransport) { + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +#ifndef _WIN32 +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithTimeout) { + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(JSONRPCTransportTest, ReadWithTimeout) { + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} +#endif diff --git a/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h new file mode 100644 index 0000000000000..50d5d4117c898 --- /dev/null +++ b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h @@ -0,0 +1,28 @@ +//===-- PipeTestUtilities.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_PIPETESTUTILITIES_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_PIPETESTUTILITIES_H + +#include "lldb/Host/Pipe.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +/// A base class for tests that need a pair of pipes for communication. +class PipeTest : public testing::Test { +protected: + lldb_private::Pipe input; + lldb_private::Pipe output; + + void SetUp() override { + ASSERT_THAT_ERROR(input.CreateNew(false).ToError(), llvm::Succeeded()); + ASSERT_THAT_ERROR(output.CreateNew(false).ToError(), llvm::Succeeded()); + } +}; + +#endif From b1c21de55d2f853704e6e030364297b08b998f87 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Fri, 20 Jun 2025 10:48:04 -0500 Subject: [PATCH 05/15] [lldb] Add Model Context Protocol (MCP) support to LLDB (#143628) This PR adds an MCP (Model Context Protocol ) server to LLDB. For motivation and background, please refer to the corresponding RFC: https://discourse.llvm.org/t/rfc-adding-mcp-support-to-lldb/86798 I implemented this as a new kind of plugin. The idea is that we could support multiple protocol servers (e.g. if we want to support DAP from within LLDB). This also introduces a corresponding top-level command (`protocol-server`) with two subcommands to `start` and `stop` the server. ``` (lldb) protocol-server start MCP tcp://localhost:1234 MCP server started with connection listeners: connection://[::1]:1234, connection://[127.0.0.1]:1234 ``` The MCP sever supports one tool (`lldb_command`) which executes a command, but can easily be extended with more commands. (cherry picked from commit 9524bfb27020d31b9474f595b7c0e5d2e1ac65f5) --- lldb/cmake/modules/LLDBConfig.cmake | 1 + lldb/include/lldb/Core/Debugger.h | 6 + lldb/include/lldb/Core/PluginManager.h | 11 + lldb/include/lldb/Core/ProtocolServer.h | 39 +++ .../Interpreter/CommandOptionArgumentTable.h | 1 + lldb/include/lldb/lldb-enumerations.h | 1 + lldb/include/lldb/lldb-forward.h | 3 +- lldb/include/lldb/lldb-private-interfaces.h | 2 + lldb/source/Commands/CMakeLists.txt | 1 + .../Commands/CommandObjectProtocolServer.cpp | 176 ++++++++++ .../Commands/CommandObjectProtocolServer.h | 25 ++ lldb/source/Core/CMakeLists.txt | 1 + lldb/source/Core/Debugger.cpp | 24 ++ lldb/source/Core/PluginManager.cpp | 32 ++ lldb/source/Core/ProtocolServer.cpp | 21 ++ .../source/Interpreter/CommandInterpreter.cpp | 2 + lldb/source/Plugins/CMakeLists.txt | 4 + lldb/source/Plugins/Protocol/CMakeLists.txt | 1 + .../Plugins/Protocol/MCP/CMakeLists.txt | 13 + lldb/source/Plugins/Protocol/MCP/MCPError.cpp | 34 ++ lldb/source/Plugins/Protocol/MCP/MCPError.h | 33 ++ lldb/source/Plugins/Protocol/MCP/Protocol.cpp | 214 ++++++++++++ lldb/source/Plugins/Protocol/MCP/Protocol.h | 128 +++++++ .../Protocol/MCP/ProtocolServerMCP.cpp | 327 ++++++++++++++++++ .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 100 ++++++ lldb/source/Plugins/Protocol/MCP/Tool.cpp | 81 +++++ lldb/source/Plugins/Protocol/MCP/Tool.h | 56 +++ lldb/unittests/CMakeLists.txt | 4 + lldb/unittests/Protocol/CMakeLists.txt | 12 + .../Protocol/ProtocolMCPServerTest.cpp | 291 ++++++++++++++++ lldb/unittests/Protocol/ProtocolMCPTest.cpp | 135 ++++++++ lldb/unittests/TestingSupport/TestUtilities.h | 9 + 32 files changed, 1787 insertions(+), 1 deletion(-) create mode 100644 lldb/include/lldb/Core/ProtocolServer.h create mode 100644 lldb/source/Commands/CommandObjectProtocolServer.cpp create mode 100644 lldb/source/Commands/CommandObjectProtocolServer.h create mode 100644 lldb/source/Core/ProtocolServer.cpp create mode 100644 lldb/source/Plugins/Protocol/CMakeLists.txt create mode 100644 lldb/source/Plugins/Protocol/MCP/CMakeLists.txt create mode 100644 lldb/source/Plugins/Protocol/MCP/MCPError.cpp create mode 100644 lldb/source/Plugins/Protocol/MCP/MCPError.h create mode 100644 lldb/source/Plugins/Protocol/MCP/Protocol.cpp create mode 100644 lldb/source/Plugins/Protocol/MCP/Protocol.h create mode 100644 lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp create mode 100644 lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h create mode 100644 lldb/source/Plugins/Protocol/MCP/Tool.cpp create mode 100644 lldb/source/Plugins/Protocol/MCP/Tool.h create mode 100644 lldb/unittests/Protocol/CMakeLists.txt create mode 100644 lldb/unittests/Protocol/ProtocolMCPServerTest.cpp create mode 100644 lldb/unittests/Protocol/ProtocolMCPTest.cpp diff --git a/lldb/cmake/modules/LLDBConfig.cmake b/lldb/cmake/modules/LLDBConfig.cmake index 70e8db40328af..23ccae5e11fa8 100644 --- a/lldb/cmake/modules/LLDBConfig.cmake +++ b/lldb/cmake/modules/LLDBConfig.cmake @@ -77,6 +77,7 @@ add_optional_dependency(LLDB_ENABLE_FBSDVMCORE "Enable libfbsdvmcore support in option(LLDB_USE_ENTITLEMENTS "When codesigning, use entitlements if available" ON) option(LLDB_BUILD_FRAMEWORK "Build LLDB.framework (Darwin only)" OFF) +option(LLDB_ENABLE_PROTOCOL_SERVERS "Enable protocol servers (e.g. MCP) in LLDB" ON) option(LLDB_NO_INSTALL_DEFAULT_RPATH "Disable default RPATH settings in binaries" OFF) option(LLDB_USE_SYSTEM_DEBUGSERVER "Use the system's debugserver for testing (Darwin only)." OFF) option(LLDB_SKIP_STRIP "Whether to skip stripping of binaries when installing lldb." OFF) diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index 35a41e419c9bf..625ecf2ed26fc 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -617,6 +617,10 @@ class Debugger : public std::enable_shared_from_this, void FlushProcessOutput(Process &process, bool flush_stdout, bool flush_stderr); + void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp); + void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp); + lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const; + SourceManager::SourceFileCache &GetSourceFileCache() { return m_source_file_cache; } @@ -789,6 +793,8 @@ class Debugger : public std::enable_shared_from_this, mutable std::mutex m_progress_reports_mutex; /// @} + llvm::SmallVector m_protocol_servers; + std::mutex m_destroy_callback_mutex; lldb::callback_token_t m_destroy_callback_next_token = 0; struct DestroyCallbackInfo { diff --git a/lldb/include/lldb/Core/PluginManager.h b/lldb/include/lldb/Core/PluginManager.h index 0c988e5969538..96bf10fa48d38 100644 --- a/lldb/include/lldb/Core/PluginManager.h +++ b/lldb/include/lldb/Core/PluginManager.h @@ -255,6 +255,17 @@ class PluginManager { static void AutoCompleteProcessName(llvm::StringRef partial_name, CompletionRequest &request); + // Protocol + static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback); + + static bool UnregisterPlugin(ProtocolServerCreateInstance create_callback); + + static llvm::StringRef GetProtocolServerPluginNameAtIndex(uint32_t idx); + + static ProtocolServerCreateInstance + GetProtocolCreateCallbackForPluginName(llvm::StringRef name); + // Register Type Provider static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, RegisterTypeBuilderCreateInstance create_callback); diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h new file mode 100644 index 0000000000000..fafe460904323 --- /dev/null +++ b/lldb/include/lldb/Core/ProtocolServer.h @@ -0,0 +1,39 @@ +//===-- ProtocolServer.h --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_CORE_PROTOCOLSERVER_H +#define LLDB_CORE_PROTOCOLSERVER_H + +#include "lldb/Core/PluginInterface.h" +#include "lldb/Host/Socket.h" +#include "lldb/lldb-private-interfaces.h" + +namespace lldb_private { + +class ProtocolServer : public PluginInterface { +public: + ProtocolServer() = default; + virtual ~ProtocolServer() = default; + + static lldb::ProtocolServerSP Create(llvm::StringRef name, + Debugger &debugger); + + struct Connection { + Socket::SocketProtocol protocol; + std::string name; + }; + + virtual llvm::Error Start(Connection connection) = 0; + virtual llvm::Error Stop() = 0; + + virtual Socket *GetSocket() const = 0; +}; + +} // namespace lldb_private + +#endif diff --git a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h index 323f519ede053..8fb3e9e95c83d 100644 --- a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h +++ b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h @@ -337,6 +337,7 @@ static constexpr CommandObject::ArgumentTableEntry g_argument_table[] = { { lldb::eArgTypeModule, "module", lldb::CompletionType::eModuleCompletion, {}, { nullptr, false }, "The name of a module loaded into the current target." }, { lldb::eArgTypeCPUName, "cpu-name", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of a CPU." }, { lldb::eArgTypeCPUFeatures, "cpu-features", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The CPU feature string." }, + { lldb::eArgTypeProtocol, "protocol", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of the protocol." }, // clang-format on }; diff --git a/lldb/include/lldb/lldb-enumerations.h b/lldb/include/lldb/lldb-enumerations.h index 882640eccc3d2..42ec593e2ee42 100644 --- a/lldb/include/lldb/lldb-enumerations.h +++ b/lldb/include/lldb/lldb-enumerations.h @@ -669,6 +669,7 @@ enum CommandArgumentType { eArgTypeModule, eArgTypeCPUName, eArgTypeCPUFeatures, + eArgTypeProtocol, eArgTypeLastArg // Always keep this entry as the last entry in this // enumeration!! }; diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h index a3550f3fe60ff..cdcd95443cc7a 100644 --- a/lldb/include/lldb/lldb-forward.h +++ b/lldb/include/lldb/lldb-forward.h @@ -164,13 +164,13 @@ class PersistentExpressionState; class Platform; class Process; class ProcessAttachInfo; -class ProcessLaunchInfo; class ProcessInfo; class ProcessInstanceInfo; class ProcessInstanceInfoMatch; class ProcessLaunchInfo; class ProcessModID; class Property; +class ProtocolServer; class Queue; class QueueImpl; class QueueItem; @@ -389,6 +389,7 @@ typedef std::shared_ptr PlatformSP; typedef std::shared_ptr ProcessSP; typedef std::shared_ptr ProcessAttachInfoSP; typedef std::shared_ptr ProcessLaunchInfoSP; +typedef std::shared_ptr ProtocolServerSP; typedef std::weak_ptr ProcessWP; typedef std::shared_ptr RegisterCheckpointSP; typedef std::shared_ptr RegisterContextSP; diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h index cd5ccc44324c3..19ab5f435659b 100644 --- a/lldb/include/lldb/lldb-private-interfaces.h +++ b/lldb/include/lldb/lldb-private-interfaces.h @@ -82,6 +82,8 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force, typedef lldb::ProcessSP (*ProcessCreateInstance)( lldb::TargetSP target_sp, lldb::ListenerSP listener_sp, const FileSpec *crash_file_path, bool can_connect); +typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)( + Debugger &debugger); typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)( Target &target); typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)( diff --git a/lldb/source/Commands/CMakeLists.txt b/lldb/source/Commands/CMakeLists.txt index 186d778305a4e..fab0e303d8b10 100644 --- a/lldb/source/Commands/CMakeLists.txt +++ b/lldb/source/Commands/CMakeLists.txt @@ -28,6 +28,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES CommandObjectPlatform.cpp CommandObjectPlugin.cpp CommandObjectProcess.cpp + CommandObjectProtocolServer.cpp CommandObjectQuit.cpp CommandObjectRegexCommand.cpp CommandObjectRegister.cpp diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp new file mode 100644 index 0000000000000..420fc5fdddadb --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp @@ -0,0 +1,176 @@ +//===-- CommandObjectProtocolServer.cpp +//----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CommandObjectProtocolServer.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/Socket.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Utility/UriParser.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatAdapters.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +#define LLDB_OPTIONS_mcp +#include "CommandOptions.inc" + +static std::vector GetSupportedProtocols() { + std::vector supported_protocols; + size_t i = 0; + + for (llvm::StringRef protocol_name = + PluginManager::GetProtocolServerPluginNameAtIndex(i++); + !protocol_name.empty(); + protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { + supported_protocols.push_back(protocol_name); + } + + return supported_protocols; +} + +class CommandObjectProtocolServerStart : public CommandObjectParsed { +public: + CommandObjectProtocolServerStart(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server start", + "start protocol server", + "protocol-server start ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + AddSimpleArgumentList(lldb::eArgTypeConnectURL, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStart() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + std::vector supported_protocols = GetSupportedProtocols(); + if (llvm::find(supported_protocols, protocol) == + supported_protocols.end()) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(GetSupportedProtocols(), ", ")); + return; + } + + if (args.GetArgumentCount() < 2) { + result.AppendError("no connection specified"); + return; + } + llvm::StringRef connection_uri = args.GetArgumentAtIndex(1); + + ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol); + if (!server_sp) + server_sp = ProtocolServer::Create(protocol, GetDebugger()); + + const char *connection_error = + "unsupported connection specifier, expected 'accept:///path' or " + "'listen://[host]:port', got '{0}'."; + auto uri = lldb_private::URI::Parse(connection_uri); + if (!uri) { + result.AppendErrorWithFormatv(connection_error, connection_uri); + return; + } + + std::optional protocol_and_mode = + Socket::GetProtocolAndMode(uri->scheme); + if (!protocol_and_mode || protocol_and_mode->second != Socket::ModeAccept) { + result.AppendErrorWithFormatv(connection_error, connection_uri); + return; + } + + ProtocolServer::Connection connection; + connection.protocol = protocol_and_mode->first; + connection.name = + formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname, + uri->port.value_or(0)); + + if (llvm::Error error = server_sp->Start(connection)) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + GetDebugger().AddProtocolServer(server_sp); + + if (Socket *socket = server_sp->GetSocket()) { + std::string address = + llvm::join(socket->GetListeningConnectionURI(), ", "); + result.AppendMessageWithFormatv( + "{0} server started with connection listeners: {1}", protocol, + address); + } + } +}; + +class CommandObjectProtocolServerStop : public CommandObjectParsed { +public: + CommandObjectProtocolServerStop(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server stop", + "stop protocol server", + "protocol-server stop ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStop() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + std::vector supported_protocols = GetSupportedProtocols(); + if (llvm::find(supported_protocols, protocol) == + supported_protocols.end()) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(GetSupportedProtocols(), ", ")); + return; + } + + Debugger &debugger = GetDebugger(); + + ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol); + if (!server_sp) { + result.AppendError( + llvm::formatv("no {0} protocol server running", protocol).str()); + return; + } + + if (llvm::Error error = server_sp->Stop()) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + debugger.RemoveProtocolServer(server_sp); + } +}; + +CommandObjectProtocolServer::CommandObjectProtocolServer( + CommandInterpreter &interpreter) + : CommandObjectMultiword(interpreter, "protocol-server", + "Start and stop a protocol server.", + "protocol-server") { + LoadSubCommand("start", CommandObjectSP(new CommandObjectProtocolServerStart( + interpreter))); + LoadSubCommand("stop", CommandObjectSP( + new CommandObjectProtocolServerStop(interpreter))); +} + +CommandObjectProtocolServer::~CommandObjectProtocolServer() = default; diff --git a/lldb/source/Commands/CommandObjectProtocolServer.h b/lldb/source/Commands/CommandObjectProtocolServer.h new file mode 100644 index 0000000000000..3591216b014cb --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.h @@ -0,0 +1,25 @@ +//===-- CommandObjectProtocolServer.h +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H +#define LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H + +#include "lldb/Interpreter/CommandObjectMultiword.h" + +namespace lldb_private { + +class CommandObjectProtocolServer : public CommandObjectMultiword { +public: + CommandObjectProtocolServer(CommandInterpreter &interpreter); + ~CommandObjectProtocolServer() override; +}; + +} // namespace lldb_private + +#endif // LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H diff --git a/lldb/source/Core/CMakeLists.txt b/lldb/source/Core/CMakeLists.txt index c6bb3cded801a..e15bff774e02f 100644 --- a/lldb/source/Core/CMakeLists.txt +++ b/lldb/source/Core/CMakeLists.txt @@ -48,6 +48,7 @@ add_lldb_library(lldbCore Opcode.cpp PluginManager.cpp Progress.cpp + ProtocolServer.cpp Statusline.cpp RichManglingContext.cpp SearchFilter.cpp diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp index 0efc9d9a4482f..d8930ccf06d3b 100644 --- a/lldb/source/Core/Debugger.cpp +++ b/lldb/source/Core/Debugger.cpp @@ -16,6 +16,7 @@ #include "lldb/Core/ModuleSpec.h" #include "lldb/Core/PluginManager.h" #include "lldb/Core/Progress.h" +#include "lldb/Core/ProtocolServer.h" #include "lldb/Core/StreamAsynchronousIO.h" #include "lldb/DataFormatters/DataVisualization.h" #include "lldb/Expression/REPL.h" @@ -2379,3 +2380,26 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() { "Debugger::GetThreadPool called before Debugger::Initialize"); return *g_thread_pool; } + +void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { + assert(protocol_server_sp && + GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr); + m_protocol_servers.push_back(protocol_server_sp); +} + +void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { + auto it = llvm::find(m_protocol_servers, protocol_server_sp); + if (it != m_protocol_servers.end()) + m_protocol_servers.erase(it); +} + +lldb::ProtocolServerSP +Debugger::GetProtocolServer(llvm::StringRef protocol) const { + for (ProtocolServerSP protocol_server_sp : m_protocol_servers) { + if (!protocol_server_sp) + continue; + if (protocol_server_sp->GetPluginName() == protocol) + return protocol_server_sp; + } + return nullptr; +} diff --git a/lldb/source/Core/PluginManager.cpp b/lldb/source/Core/PluginManager.cpp index 8a19684d63f28..ed93f7dee6597 100644 --- a/lldb/source/Core/PluginManager.cpp +++ b/lldb/source/Core/PluginManager.cpp @@ -905,6 +905,38 @@ void PluginManager::AutoCompleteProcessName(llvm::StringRef name, } } +#pragma mark ProtocolServer + +typedef PluginInstance ProtocolServerInstance; +typedef PluginInstances ProtocolServerInstances; + +static ProtocolServerInstances &GetProtocolServerInstances() { + static ProtocolServerInstances g_instances; + return g_instances; +} + +bool PluginManager::RegisterPlugin( + llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().RegisterPlugin(name, description, + create_callback); +} + +bool PluginManager::UnregisterPlugin( + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().UnregisterPlugin(create_callback); +} + +llvm::StringRef +PluginManager::GetProtocolServerPluginNameAtIndex(uint32_t idx) { + return GetProtocolServerInstances().GetNameAtIndex(idx); +} + +ProtocolServerCreateInstance +PluginManager::GetProtocolCreateCallbackForPluginName(llvm::StringRef name) { + return GetProtocolServerInstances().GetCallbackForName(name); +} + #pragma mark RegisterTypeBuilder struct RegisterTypeBuilderInstance diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp new file mode 100644 index 0000000000000..d57a047afa7b2 --- /dev/null +++ b/lldb/source/Core/ProtocolServer.cpp @@ -0,0 +1,21 @@ +//===-- ProtocolServer.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Core/PluginManager.h" + +using namespace lldb_private; +using namespace lldb; + +ProtocolServerSP ProtocolServer::Create(llvm::StringRef name, + Debugger &debugger) { + if (ProtocolServerCreateInstance create_callback = + PluginManager::GetProtocolCreateCallbackForPluginName(name)) + return create_callback(debugger); + return nullptr; +} diff --git a/lldb/source/Interpreter/CommandInterpreter.cpp b/lldb/source/Interpreter/CommandInterpreter.cpp index 68831d2831749..231b9c08d7150 100644 --- a/lldb/source/Interpreter/CommandInterpreter.cpp +++ b/lldb/source/Interpreter/CommandInterpreter.cpp @@ -31,6 +31,7 @@ #include "Commands/CommandObjectPlatform.h" #include "Commands/CommandObjectPlugin.h" #include "Commands/CommandObjectProcess.h" +#include "Commands/CommandObjectProtocolServer.h" #include "Commands/CommandObjectQuit.h" #include "Commands/CommandObjectRegexCommand.h" #include "Commands/CommandObjectRegister.h" @@ -583,6 +584,7 @@ void CommandInterpreter::LoadCommandDictionary() { REGISTER_COMMAND_OBJECT("platform", CommandObjectPlatform); REGISTER_COMMAND_OBJECT("plugin", CommandObjectPlugin); REGISTER_COMMAND_OBJECT("process", CommandObjectMultiwordProcess); + REGISTER_COMMAND_OBJECT("protocol-server", CommandObjectProtocolServer); REGISTER_COMMAND_OBJECT("quit", CommandObjectQuit); REGISTER_COMMAND_OBJECT("register", CommandObjectRegister); REGISTER_COMMAND_OBJECT("scripting", CommandObjectMultiwordScripting); diff --git a/lldb/source/Plugins/CMakeLists.txt b/lldb/source/Plugins/CMakeLists.txt index 854f589f45ae0..08f444e7b15e8 100644 --- a/lldb/source/Plugins/CMakeLists.txt +++ b/lldb/source/Plugins/CMakeLists.txt @@ -27,6 +27,10 @@ add_subdirectory(TraceExporter) add_subdirectory(TypeSystem) add_subdirectory(UnwindAssembly) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + set(LLDB_STRIPPED_PLUGINS) get_property(LLDB_ALL_PLUGINS GLOBAL PROPERTY LLDB_PLUGINS) diff --git a/lldb/source/Plugins/Protocol/CMakeLists.txt b/lldb/source/Plugins/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..93b347d4cc9d8 --- /dev/null +++ b/lldb/source/Plugins/Protocol/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MCP) diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt new file mode 100644 index 0000000000000..db31a7a69cb33 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt @@ -0,0 +1,13 @@ +add_lldb_library(lldbPluginProtocolServerMCP PLUGIN + MCPError.cpp + Protocol.cpp + ProtocolServerMCP.cpp + Tool.cpp + + LINK_COMPONENTS + Support + + LINK_LIBS + lldbHost + lldbUtility +) diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp new file mode 100644 index 0000000000000..5ed850066b659 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp @@ -0,0 +1,34 @@ +//===-- MCPError.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "MCPError.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace lldb_private::mcp { + +char MCPError::ID; + +MCPError::MCPError(std::string message, int64_t error_code) + : m_message(message), m_error_code(error_code) {} + +void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } + +std::error_code MCPError::convertToErrorCode() const { + return llvm::inconvertibleErrorCode(); +} + +protocol::Error MCPError::toProtcolError() const { + protocol::Error error; + error.error.code = m_error_code; + error.error.message = m_message; + return error; +} + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/source/Plugins/Protocol/MCP/MCPError.h new file mode 100644 index 0000000000000..2a76a7b087e20 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.h @@ -0,0 +1,33 @@ +//===-- MCPError.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/Error.h" +#include + +namespace lldb_private::mcp { + +class MCPError : public llvm::ErrorInfo { +public: + static char ID; + + MCPError(std::string message, int64_t error_code); + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + + const std::string &getMessage() const { return m_message; } + + protocol::Error toProtcolError() const; + +private: + std::string m_message; + int64_t m_error_code; +}; + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp new file mode 100644 index 0000000000000..d66c931a0b284 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp @@ -0,0 +1,214 @@ +//===- Protocol.cpp -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/JSON.h" + +using namespace llvm; + +namespace lldb_private::mcp::protocol { + +static bool mapRaw(const json::Value &Params, StringLiteral Prop, + std::optional &V, json::Path P) { + const auto *O = Params.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + const json::Value *E = O->get(Prop); + if (E) + V = std::move(*E); + return true; +} + +llvm::json::Value toJSON(const Request &R) { + json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + if (R.params) + Result.insert({"params", R.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("method", R.method)) + return false; + return mapRaw(V, "params", R.params, P); +} + +llvm::json::Value toJSON(const ErrorInfo &EI) { + llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; + if (EI.data) + Result.insert({"data", EI.data}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("code", EI.code) && O.map("message", EI.message) && + O.mapOptional("data", EI.data); +} + +llvm::json::Value toJSON(const Error &E) { + return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; +} + +bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("id", E.id) && O.map("error", E.error); +} + +llvm::json::Value toJSON(const Response &R) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; + if (R.result) + Result.insert({"result", R.result}); + if (R.error) + Result.insert({"error", R.error}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + return false; + return mapRaw(V, "result", R.result, P); +} + +llvm::json::Value toJSON(const Notification &N) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"method", N.method}}; + if (N.params) + Result.insert({"params", N.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("method", N.method)) + return false; + auto *Obj = V.getAsObject(); + if (!Obj) + return false; + if (auto *Params = Obj->get("params")) + N.params = *Params; + return true; +} + +llvm::json::Value toJSON(const ToolCapability &TC) { + return llvm::json::Object{{"listChanged", TC.listChanged}}; +} + +bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("listChanged", TC.listChanged); +} + +llvm::json::Value toJSON(const Capabilities &C) { + return llvm::json::Object{{"tools", C.tools}}; +} + +bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("tools", C.tools); +} + +llvm::json::Value toJSON(const TextContent &TC) { + return llvm::json::Object{{"type", "text"}, {"text", TC.text}}; +} + +bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("text", TC.text); +} + +llvm::json::Value toJSON(const TextResult &TR) { + return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; +} + +bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("content", TR.content) && O.map("isError", TR.isError); +} + +llvm::json::Value toJSON(const ToolDefinition &TD) { + llvm::json::Object Result{{"name", TD.name}}; + if (TD.description) + Result.insert({"description", TD.description}); + if (TD.inputSchema) + Result.insert({"inputSchema", TD.inputSchema}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ToolDefinition &TD, + llvm::json::Path P) { + + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("name", TD.name) || + !O.mapOptional("description", TD.description)) + return false; + return mapRaw(V, "inputSchema", TD.inputSchema, P); +} + +llvm::json::Value toJSON(const Message &M) { + return std::visit([](auto &M) { return toJSON(M); }, M); +} + +bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (const json::Value *V = O->get("jsonrpc")) { + if (V->getAsString().value_or("") != "2.0") { + P.report("unsupported JSON RPC version"); + return false; + } + } else { + P.report("not a valid JSON RPC message"); + return false; + } + + // A message without an ID is a Notification. + if (!O->get("id")) { + protocol::Notification N; + if (!fromJSON(V, N, P)) + return false; + M = std::move(N); + return true; + } + + if (O->get("error")) { + protocol::Error E; + if (!fromJSON(V, E, P)) + return false; + M = std::move(E); + return true; + } + + if (O->get("result")) { + protocol::Response R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + if (O->get("method")) { + protocol::Request R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + P.report("unrecognized message type"); + return false; +} + +} // namespace lldb_private::mcp::protocol diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h new file mode 100644 index 0000000000000..e315899406573 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -0,0 +1,128 @@ +//===- Protocol.h ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains POD structs based on the MCP specification at +// https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2024-11-05/schema.json +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H + +#include "llvm/Support/JSON.h" +#include +#include +#include + +namespace lldb_private::mcp::protocol { + +static llvm::StringLiteral kVersion = "2024-11-05"; + +/// A request that expects a response. +struct Request { + uint64_t id = 0; + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Request &); +bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); + +struct ErrorInfo { + int64_t code = 0; + std::string message; + std::optional data; +}; + +llvm::json::Value toJSON(const ErrorInfo &); +bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path); + +struct Error { + uint64_t id = 0; + ErrorInfo error; +}; + +llvm::json::Value toJSON(const Error &); +bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); + +struct Response { + uint64_t id = 0; + std::optional result; + std::optional error; +}; + +llvm::json::Value toJSON(const Response &); +bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); + +/// A notification which does not expect a response. +struct Notification { + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Notification &); +bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); + +struct ToolCapability { + /// Whether this server supports notifications for changes to the tool list. + bool listChanged = false; +}; + +llvm::json::Value toJSON(const ToolCapability &); +bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct Capabilities { + /// Present if the server offers any tools to call. + ToolCapability tools; +}; + +llvm::json::Value toJSON(const Capabilities &); +bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); + +/// Text provided to or from an LLM. +struct TextContent { + /// The text content of the message. + std::string text; +}; + +llvm::json::Value toJSON(const TextContent &); +bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path); + +struct TextResult { + std::vector content; + bool isError = false; +}; + +llvm::json::Value toJSON(const TextResult &); +bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path); + +struct ToolDefinition { + /// Unique identifier for the tool. + std::string name; + + /// Human-readable description. + std::optional description; + + // JSON Schema for the tool's parameters. + std::optional inputSchema; +}; + +llvm::json::Value toJSON(const ToolDefinition &); +bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); + +using Message = std::variant; + +bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); +llvm::json::Value toJSON(const Message &); + +} // namespace lldb_private::mcp::protocol + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp new file mode 100644 index 0000000000000..029d4a887b0cc --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -0,0 +1,327 @@ +//===- ProtocolServerMCP.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ProtocolServerMCP.h" +#include "MCPError.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Threading.h" +#include +#include + +using namespace lldb_private; +using namespace lldb_private::mcp; +using namespace llvm; + +LLDB_PLUGIN_DEFINE(ProtocolServerMCP) + +static constexpr size_t kChunkSize = 1024; + +ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) + : ProtocolServer(), m_debugger(debugger) { + AddRequestHandler("initialize", + std::bind(&ProtocolServerMCP::InitializeHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/list", + std::bind(&ProtocolServerMCP::ToolsListHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/call", + std::bind(&ProtocolServerMCP::ToolsCallHandler, this, + std::placeholders::_1)); + AddNotificationHandler( + "notifications/initialized", [](const protocol::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); + }); + AddTool(std::make_unique( + "lldb_command", "Run an lldb command.", m_debugger)); +} + +ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } + +void ProtocolServerMCP::Initialize() { + PluginManager::RegisterPlugin(GetPluginNameStatic(), + GetPluginDescriptionStatic(), CreateInstance); +} + +void ProtocolServerMCP::Terminate() { + PluginManager::UnregisterPlugin(CreateInstance); +} + +lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { + return std::make_shared(debugger); +} + +llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { + return "MCP Server."; +} + +llvm::Expected +ProtocolServerMCP::Handle(protocol::Request request) { + auto it = m_request_handlers.find(request.method); + if (it != m_request_handlers.end()) { + llvm::Expected response = it->second(request); + if (!response) + return response; + response->id = request.id; + return *response; + } + + return make_error( + llvm::formatv("no handler for request: {0}", request.method).str(), 1); +} + +void ProtocolServerMCP::Handle(protocol::Notification notification) { + auto it = m_notification_handlers.find(notification.method); + if (it != m_notification_handlers.end()) { + it->second(notification); + return; + } + + LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", + notification.method, notification.params); +} + +void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { + LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", + m_clients.size() + 1); + + lldb::IOObjectSP io_sp = std::move(socket); + auto client_up = std::make_unique(); + client_up->io_sp = io_sp; + Client *client = client_up.get(); + + Status status; + auto read_handle_up = m_loop.RegisterReadObject( + io_sp, + [this, client](MainLoopBase &loop) { + if (Error error = ReadCallback(*client)) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); + client->read_handle_up.reset(); + } + }, + status); + if (status.Fail()) + return; + + client_up->read_handle_up = std::move(read_handle_up); + m_clients.emplace_back(std::move(client_up)); +} + +llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { + char chunk[kChunkSize]; + size_t bytes_read = sizeof(chunk); + if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) + return status.takeError(); + client.buffer.append(chunk, bytes_read); + + for (std::string::size_type pos; + (pos = client.buffer.find('\n')) != std::string::npos;) { + llvm::Expected> message = + HandleData(StringRef(client.buffer.data(), pos)); + client.buffer = client.buffer.erase(0, pos + 1); + if (!message) + return message.takeError(); + + if (*message) { + std::string Output; + llvm::raw_string_ostream OS(Output); + OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; + size_t num_bytes = Output.size(); + return client.io_sp->Write(Output.data(), num_bytes).takeError(); + } + } + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { + std::lock_guard guard(m_server_mutex); + + if (m_running) + return llvm::createStringError("server already running"); + + Status status; + m_listener = Socket::Create(connection.protocol, status); + if (status.Fail()) + return status.takeError(); + + status = m_listener->Listen(connection.name, /*backlog=*/5); + if (status.Fail()) + return status.takeError(); + + std::string address = + llvm::join(m_listener->GetListeningConnectionURI(), ", "); + auto handles = + m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback, + this, std::placeholders::_1)); + if (llvm::Error error = handles.takeError()) + return error; + + m_listen_handlers = std::move(*handles); + m_loop_thread = std::thread([=] { + llvm::set_thread_name( + llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID())); + m_loop.Run(); + }); + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Stop() { + { + std::lock_guard guard(m_server_mutex); + m_running = false; + } + + // Stop the main loop. + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + + // Wait for the main loop to exit. + if (m_loop_thread.joinable()) + m_loop_thread.join(); + + { + std::lock_guard guard(m_server_mutex); + m_listener.reset(); + m_listen_handlers.clear(); + m_clients.clear(); + } + + return llvm::Error::success(); +} + +llvm::Expected> +ProtocolServerMCP::HandleData(llvm::StringRef data) { + auto message = llvm::json::parse(/*JSON=*/data); + if (!message) + return message.takeError(); + + if (const protocol::Request *request = + std::get_if(&(*message))) { + llvm::Expected response = Handle(*request); + + // Handle failures by converting them into an Error message. + if (!response) { + protocol::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.error.code = -1; + protocol_error.error.message = err.message(); + }); + protocol_error.id = request->id; + return protocol_error; + } + + return *response; + } + + if (const protocol::Notification *notification = + std::get_if(&(*message))) { + Handle(*notification); + return std::nullopt; + } + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: error"); + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: response"); + + llvm_unreachable("all message types handled"); +} + +protocol::Capabilities ProtocolServerMCP::GetCapabilities() { + protocol::Capabilities capabilities; + capabilities.tools.listChanged = true; + return capabilities; +} + +void ProtocolServerMCP::AddTool(std::unique_ptr tool) { + std::lock_guard guard(m_server_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, + RequestHandler handler) { + std::lock_guard guard(m_server_mutex); + m_request_handlers[method] = std::move(handler); +} + +void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler) { + std::lock_guard guard(m_server_mutex); + m_notification_handlers[method] = std::move(handler); +} + +llvm::Expected +ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { + protocol::Response response; + response.result.emplace(llvm::json::Object{ + {"protocolVersion", protocol::kVersion}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return llvm::createStringError("no tool parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no tool parameters"); + + const json::Value *name = param_obj->get("name"); + if (!name) + return llvm::createStringError("no tool name"); + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return llvm::createStringError("no tool name"); + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); + + const json::Value *args = param_obj->get("arguments"); + if (!args) + return llvm::createStringError("no tool arguments"); + + llvm::Expected text_result = it->second->Call(*args); + if (!text_result) + return text_result.takeError(); + + response.result.emplace(toJSON(*text_result)); + + return response; +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h new file mode 100644 index 0000000000000..52bb92a04a802 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -0,0 +1,100 @@ +//===- ProtocolServerMCP.h ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H + +#include "Protocol.h" +#include "Tool.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/Socket.h" +#include "llvm/ADT/StringMap.h" +#include + +namespace lldb_private::mcp { + +class ProtocolServerMCP : public ProtocolServer { +public: + ProtocolServerMCP(Debugger &debugger); + virtual ~ProtocolServerMCP() override; + + virtual llvm::Error Start(ProtocolServer::Connection connection) override; + virtual llvm::Error Stop() override; + + static void Initialize(); + static void Terminate(); + + static llvm::StringRef GetPluginNameStatic() { return "MCP"; } + static llvm::StringRef GetPluginDescriptionStatic(); + + static lldb::ProtocolServerSP CreateInstance(Debugger &debugger); + + llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } + + Socket *GetSocket() const override { return m_listener.get(); } + +protected: + using RequestHandler = std::function( + const protocol::Request &)>; + using NotificationHandler = + std::function; + + void AddTool(std::unique_ptr tool); + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + +private: + void AcceptCallback(std::unique_ptr socket); + + llvm::Expected> + HandleData(llvm::StringRef data); + + llvm::Expected Handle(protocol::Request request); + void Handle(protocol::Notification notification); + + llvm::Expected + InitializeHandler(const protocol::Request &); + llvm::Expected + ToolsListHandler(const protocol::Request &); + llvm::Expected + ToolsCallHandler(const protocol::Request &); + + protocol::Capabilities GetCapabilities(); + + llvm::StringLiteral kName = "lldb-mcp"; + llvm::StringLiteral kVersion = "0.1.0"; + + Debugger &m_debugger; + + bool m_running = false; + + MainLoop m_loop; + std::thread m_loop_thread; + + std::unique_ptr m_listener; + std::vector m_listen_handlers; + + struct Client { + lldb::IOObjectSP io_sp; + MainLoopBase::ReadHandleUP read_handle_up; + std::string buffer; + }; + llvm::Error ReadCallback(Client &client); + std::vector> m_clients; + + std::mutex m_server_mutex; + llvm::StringMap> m_tools; + + llvm::StringMap m_request_handlers; + llvm::StringMap m_notification_handlers; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp new file mode 100644 index 0000000000000..de8fcc8f3cb4c --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -0,0 +1,81 @@ +//===- Tool.cpp -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Tool.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" + +using namespace lldb_private::mcp; +using namespace llvm; + +struct LLDBCommandToolArguments { + std::string arguments; +}; + +bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("arguments", A.arguments); +} + +Tool::Tool(std::string name, std::string description) + : m_name(std::move(name)), m_description(std::move(description)) {} + +protocol::ToolDefinition Tool::GetDefinition() const { + protocol::ToolDefinition definition; + definition.name = m_name; + definition.description.emplace(m_description); + + if (std::optional input_schema = GetSchema()) + definition.inputSchema = *input_schema; + + return definition; +} + +LLDBCommandTool::LLDBCommandTool(std::string name, std::string description, + Debugger &debugger) + : Tool(std::move(name), std::move(description)), m_debugger(debugger) {} + +llvm::Expected +LLDBCommandTool::Call(const llvm::json::Value &args) { + llvm::json::Path::Root root; + + LLDBCommandToolArguments arguments; + if (!fromJSON(args, arguments, root)) + return root.getError(); + + // FIXME: Disallow certain commands and their aliases. + CommandReturnObject result(/*colors=*/false); + m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(), + eLazyBoolYes, result); + + std::string output; + llvm::StringRef output_str = result.GetOutputString(); + if (!output_str.empty()) + output += output_str.str(); + + std::string err_str = result.GetErrorString(); + if (!err_str.empty()) { + if (!output.empty()) + output += '\n'; + output += err_str; + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); + text_result.isError = !result.Succeeded(); + return text_result; +} + +std::optional LLDBCommandTool::GetSchema() const { + llvm::json::Object str_type{{"type", "string"}}; + llvm::json::Object properties{{"arguments", std::move(str_type)}}; + llvm::json::Object schema{{"type", "object"}, + {"properties", std::move(properties)}}; + return schema; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h new file mode 100644 index 0000000000000..57a5125813b76 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -0,0 +1,56 @@ +//===- Tool.h -------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H + +#include "Protocol.h" +#include "lldb/Core/Debugger.h" +#include "llvm/Support/JSON.h" +#include + +namespace lldb_private::mcp { + +class Tool { +public: + Tool(std::string name, std::string description); + virtual ~Tool() = default; + + virtual llvm::Expected + Call(const llvm::json::Value &args) = 0; + + virtual std::optional GetSchema() const { + return std::nullopt; + } + + protocol::ToolDefinition GetDefinition() const; + + const std::string &GetName() { return m_name; } + +private: + std::string m_name; + std::string m_description; +}; + +class LLDBCommandTool : public mcp::Tool { +public: + LLDBCommandTool(std::string name, std::string description, + Debugger &debugger); + ~LLDBCommandTool() = default; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override; + + virtual std::optional GetSchema() const override; + +private: + Debugger &m_debugger; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 926f8a8602472..95af91ea05883 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -84,6 +84,10 @@ add_subdirectory(Utility) add_subdirectory(Thread) add_subdirectory(ValueObject) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) add_subdirectory(debugserver) endif() diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..801662b0544d8 --- /dev/null +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -0,0 +1,12 @@ +add_lldb_unittest(ProtocolTests + ProtocolMCPTest.cpp + ProtocolMCPServerTest.cpp + + LINK_LIBS + lldbCore + lldbUtility + lldbHost + lldbPluginPlatformMacOSX + lldbPluginProtocolServerMCP + LLVMTestingSupport + ) diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp new file mode 100644 index 0000000000000..72b8c7b1fd825 --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -0,0 +1,291 @@ +//===-- ProtocolServerMCPTest.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" +#include "Plugins/Protocol/MCP/ProtocolServerMCP.h" +#include "TestingSupport/Host/SocketTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/Socket.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +namespace { +class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { +public: + using ProtocolServerMCP::AddNotificationHandler; + using ProtocolServerMCP::AddRequestHandler; + using ProtocolServerMCP::AddTool; + using ProtocolServerMCP::GetSocket; + using ProtocolServerMCP::ProtocolServerMCP; +}; + +class TestJSONTransport : public lldb_private::JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + using JSONRPCTransport::ReadImpl; + using JSONRPCTransport::WriteImpl; +}; + +/// Test tool that returns it argument as text. +class TestTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + std::string argument; + if (const json::Object *args_obj = args.getAsObject()) { + if (const json::Value *s = args_obj->get("arguments")) { + argument = s->getAsString().value_or(""); + } + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{argument}}); + return text_result; + } +}; + +/// Test tool that returns an error. +class ErrorTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + return llvm::createStringError("error"); + } +}; + +/// Test tool that fails but doesn't return an error. +class FailTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); + text_result.isError = true; + return text_result; + } +}; + +class ProtocolServerMCPTest : public ::testing::Test { +public: + SubsystemRAII subsystems; + DebuggerSP m_debugger_sp; + + lldb::IOObjectSP m_io_sp; + std::unique_ptr m_transport_up; + std::unique_ptr m_server_up; + + static constexpr llvm::StringLiteral k_localhost = "localhost"; + + llvm::Error Write(llvm::StringRef message) { + return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + } + + llvm::Expected Read() { + return m_transport_up->ReadImpl(std::chrono::milliseconds(100)); + } + + void SetUp() { + // Create a debugger. + ArchSpec arch("arm64-apple-macosx-"); + Platform::SetHostPlatform( + PlatformRemoteMacOSX::CreateInstance(true, &arch)); + m_debugger_sp = Debugger::CreateInstance(); + + // Create & start the server. + ProtocolServer::Connection connection; + connection.protocol = Socket::SocketProtocol::ProtocolTcp; + connection.name = llvm::formatv("{0}:0", k_localhost).str(); + m_server_up = std::make_unique(*m_debugger_sp); + m_server_up->AddTool(std::make_unique("test", "test tool")); + ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); + + // Connect to the server over a TCP socket. + auto connect_socket_up = std::make_unique(true); + ASSERT_THAT_ERROR(connect_socket_up + ->Connect(llvm::formatv("{0}:{1}", k_localhost, + static_cast( + m_server_up->GetSocket()) + ->GetLocalPortNumber()) + .str()) + .ToError(), + llvm::Succeeded()); + + // Set up JSON transport for the client. + m_io_sp = std::move(connect_socket_up); + m_transport_up = std::make_unique(m_io_sp, m_io_sp); + } + + void TearDown() { + // Stop the server. + ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); + } +}; + +} // namespace + +TEST_F(ProtocolServerMCPTest, Intialization) { + llvm::StringLiteral request = + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"claude-ai","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + llvm::StringLiteral response = + R"json({"jsonrpc":"2.0","id":0,"result":{"capabilities":{"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsList) { + llvm::StringLiteral request = + R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; + llvm::StringLiteral response = + R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"}},"type":"object"},"name":"lldb_command"}]}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ResourcesList) { + llvm::StringLiteral request = + R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; + llvm::StringLiteral response = + R"json({"error":{"code":1,"message":"no handler for request: resources/list"},"id":2,"jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCall) { + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallError) { + m_server_up->AddTool(std::make_unique("error", "error tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"error":{"code":-1,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallFail) { + m_server_up->AddTool(std::make_unique("fail", "fail tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, NotificationInitialized) { + bool handler_called = false; + std::condition_variable cv; + std::mutex mutex; + + m_server_up->AddNotificationHandler( + "notifications/initialized", + [&](const mcp::protocol::Notification ¬ification) { + { + std::lock_guard lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); + llvm::StringLiteral request = + R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return handler_called; }); +} diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp new file mode 100644 index 0000000000000..00959f3ce20be --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -0,0 +1,135 @@ +//===-- ProtocolMCPTest.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Protocol/MCP/Protocol.h" +#include "TestingSupport/TestUtilities.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +TEST(ProtocolMCPTest, Request) { + Request request; + request.id = 1; + request.method = "foo"; + request.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_request = roundtripJSON(request); + ASSERT_THAT_EXPECTED(deserialized_request, llvm::Succeeded()); + + EXPECT_EQ(request.id, deserialized_request->id); + EXPECT_EQ(request.method, deserialized_request->method); + EXPECT_EQ(request.params, deserialized_request->params); +} + +TEST(ProtocolMCPTest, Response) { + Response response; + response.id = 1; + response.result = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_response = roundtripJSON(response); + ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + + EXPECT_EQ(response.id, deserialized_response->id); + EXPECT_EQ(response.result, deserialized_response->result); +} + +TEST(ProtocolMCPTest, Notification) { + Notification notification; + notification.method = "notifyMethod"; + notification.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_notification = + roundtripJSON(notification); + ASSERT_THAT_EXPECTED(deserialized_notification, llvm::Succeeded()); + + EXPECT_EQ(notification.method, deserialized_notification->method); + EXPECT_EQ(notification.params, deserialized_notification->params); +} + +TEST(ProtocolMCPTest, ToolCapability) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + llvm::Expected deserialized_tool_capability = + roundtripJSON(tool_capability); + ASSERT_THAT_EXPECTED(deserialized_tool_capability, llvm::Succeeded()); + + EXPECT_EQ(tool_capability.listChanged, + deserialized_tool_capability->listChanged); +} + +TEST(ProtocolMCPTest, Capabilities) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + Capabilities capabilities; + capabilities.tools = tool_capability; + + llvm::Expected deserialized_capabilities = + roundtripJSON(capabilities); + ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); + + EXPECT_EQ(capabilities.tools.listChanged, + deserialized_capabilities->tools.listChanged); +} + +TEST(ProtocolMCPTest, TextContent) { + TextContent text_content; + text_content.text = "Sample text"; + + llvm::Expected deserialized_text_content = + roundtripJSON(text_content); + ASSERT_THAT_EXPECTED(deserialized_text_content, llvm::Succeeded()); + + EXPECT_EQ(text_content.text, deserialized_text_content->text); +} + +TEST(ProtocolMCPTest, TextResult) { + TextContent text_content1; + text_content1.text = "Text 1"; + + TextContent text_content2; + text_content2.text = "Text 2"; + + TextResult text_result; + text_result.content = {text_content1, text_content2}; + text_result.isError = true; + + llvm::Expected deserialized_text_result = + roundtripJSON(text_result); + ASSERT_THAT_EXPECTED(deserialized_text_result, llvm::Succeeded()); + + EXPECT_EQ(text_result.isError, deserialized_text_result->isError); + ASSERT_EQ(text_result.content.size(), + deserialized_text_result->content.size()); + EXPECT_EQ(text_result.content[0].text, + deserialized_text_result->content[0].text); + EXPECT_EQ(text_result.content[1].text, + deserialized_text_result->content[1].text); +} + +TEST(ProtocolMCPTest, ToolDefinition) { + ToolDefinition tool_definition; + tool_definition.name = "ToolName"; + tool_definition.description = "Tool Description"; + tool_definition.inputSchema = + llvm::json::Object{{"schemaKey", "schemaValue"}}; + + llvm::Expected deserialized_tool_definition = + roundtripJSON(tool_definition); + ASSERT_THAT_EXPECTED(deserialized_tool_definition, llvm::Succeeded()); + + EXPECT_EQ(tool_definition.name, deserialized_tool_definition->name); + EXPECT_EQ(tool_definition.description, + deserialized_tool_definition->description); + EXPECT_EQ(tool_definition.inputSchema, + deserialized_tool_definition->inputSchema); +} diff --git a/lldb/unittests/TestingSupport/TestUtilities.h b/lldb/unittests/TestingSupport/TestUtilities.h index 7d040d64db8d8..a8bdda6ad33ae 100644 --- a/lldb/unittests/TestingSupport/TestUtilities.h +++ b/lldb/unittests/TestingSupport/TestUtilities.h @@ -56,6 +56,15 @@ class TestFile { std::string Buffer; }; + +template static llvm::Expected roundtripJSON(const T &input) { + llvm::json::Value value = toJSON(input); + llvm::json::Path::Root root; + T output; + if (!fromJSON(value, output, root)) + return root.getError(); + return output; +} } // namespace lldb_private #endif From 94b5a261c14aad8f4d6d3acbaf332587254fcc5e Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Thu, 19 Jun 2025 20:48:07 -0500 Subject: [PATCH 06/15] [lldb-dap] Make connection URLs match lldb (#144770) Use the same scheme as ConnectionFileDescriptor::Connect and use "listen" and "accept". Addresses feedback from a Pavel in a different PR [1]. [1] https://github.com/llvm/llvm-project/pull/143628#discussion_r2152225200 (cherry picked from commit 4f991cc99523e4bb7a0d96cee9f5c3a64bf2bc8e) --- lldb/include/lldb/Host/Socket.h | 9 +++++++++ lldb/source/Host/common/Socket.cpp | 32 +++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index c982b6d9af9a1..59de18424d7c8 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -73,6 +73,11 @@ class Socket : public IOObject { ProtocolUnixAbstract }; + enum SocketMode { + ModeAccept, + ModeConnect, + }; + struct HostAndPort { std::string hostname; uint16_t port; @@ -82,6 +87,10 @@ class Socket : public IOObject { } }; + using ProtocolModePair = std::pair; + static std::optional + GetProtocolAndMode(llvm::StringRef scheme); + static const NativeSocket kInvalidSocketValue; ~Socket() override; diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index 2ed6e30cc1566..30a356f034803 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -295,7 +295,8 @@ Socket::UdpConnect(llvm::StringRef host_and_port, return UDPSocket::Connect(host_and_port, child_processes_inherit); } -llvm::Expected Socket::DecodeHostAndPort(llvm::StringRef host_and_port) { +llvm::Expected +Socket::DecodeHostAndPort(llvm::StringRef host_and_port) { static llvm::Regex g_regex("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)"); HostAndPort ret; llvm::SmallVector matches; @@ -371,8 +372,8 @@ Status Socket::Write(const void *buf, size_t &num_bytes) { ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 " (error = %s)", static_cast(this), static_cast(m_socket), buf, - static_cast(src_len), - static_cast(bytes_sent), error.AsCString()); + static_cast(src_len), static_cast(bytes_sent), + error.AsCString()); } return error; @@ -497,3 +498,28 @@ llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS, const Socket::HostAndPort &HP) { return OS << '[' << HP.hostname << ']' << ':' << HP.port; } + +std::optional +Socket::GetProtocolAndMode(llvm::StringRef scheme) { + // Keep in sync with ConnectionFileDescriptor::Connect. + return llvm::StringSwitch>(scheme) + .Case("listen", ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeAccept}) + .Cases("accept", "unix-accept", + ProtocolModePair{SocketProtocol::ProtocolUnixDomain, + SocketMode::ModeAccept}) + .Case("unix-abstract-accept", + ProtocolModePair{SocketProtocol::ProtocolUnixAbstract, + SocketMode::ModeAccept}) + .Cases("connect", "tcp-connect", + ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeConnect}) + .Case("udp", ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeConnect}) + .Case("unix-connect", ProtocolModePair{SocketProtocol::ProtocolUnixDomain, + SocketMode::ModeConnect}) + .Case("unix-abstract-connect", + ProtocolModePair{SocketProtocol::ProtocolUnixAbstract, + SocketMode::ModeConnect}) + .Default(std::nullopt); +} From c7806a822336f37cf1c5ba1ed3c5796190bed757 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Wed, 30 Jul 2025 20:04:15 -0700 Subject: [PATCH 07/15] Adjust for older API interface --- lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp | 2 +- lldb/unittests/Protocol/ProtocolMCPServerTest.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 029d4a887b0cc..51caa9a0bd599 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -148,7 +148,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { return llvm::createStringError("server already running"); Status status; - m_listener = Socket::Create(connection.protocol, status); + m_listener = Socket::Create(connection.protocol, false, status); if (status.Fail()) return status.takeError(); diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 72b8c7b1fd825..ae1f71c0ffee9 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -120,7 +120,7 @@ class ProtocolServerMCPTest : public ::testing::Test { ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); // Connect to the server over a TCP socket. - auto connect_socket_up = std::make_unique(true); + auto connect_socket_up = std::make_unique(true, false); ASSERT_THAT_ERROR(connect_socket_up ->Connect(llvm::formatv("{0}:{1}", k_localhost, static_cast( From 78ea96a0612fd48448142935bec03c64d41a1cd6 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Tue, 1 Jul 2025 15:56:22 -0700 Subject: [PATCH 08/15] [lldb] Fix PipeTest name collision in unit tests We had two classes named `PipeTest`: one in `PipeTestUtilities.h` and one in `PipeTest.cpp`. The latter was unintentionally using the wrong class (from the header) which didn't initialize the HostInfo subsystem. This resulted in a crash due to a nullptr dereference (`g_fields`) when `PipePosix::CreateWithUniqueName` called `HostInfoBase::GetProcessTempDir`. (cherry picked from commit e89458d3985c1b612b8a64914c887a3ce3dd3509) --- lldb/unittests/DAP/TestBase.cpp | 2 +- lldb/unittests/Host/JSONTransportTest.cpp | 4 ++-- lldb/unittests/TestingSupport/Host/PipeTestUtilities.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 27ad42686fbbf..d5d36158d68e0 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -29,7 +29,7 @@ using lldb_private::NativeFile; using lldb_private::Pipe; void TransportBase::SetUp() { - PipeTest::SetUp(); + PipePairTest::SetUp(); to_dap = std::make_unique( "to_dap", nullptr, std::make_shared(input.GetReadFileDescriptor(), diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index f1ec5e03bbeca..d54d121500be0 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -14,12 +14,12 @@ using namespace llvm; using namespace lldb_private; namespace { -template class JSONTransportTest : public PipeTest { +template class JSONTransportTest : public PipePairTest { protected: std::unique_ptr transport; void SetUp() override { - PipeTest::SetUp(); + PipePairTest::SetUp(); transport = std::make_unique( std::make_shared(input.GetReadFileDescriptor(), File::eOpenOptionReadOnly, diff --git a/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h index 50d5d4117c898..87a85ad77e65d 100644 --- a/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h @@ -14,7 +14,7 @@ #include "gtest/gtest.h" /// A base class for tests that need a pair of pipes for communication. -class PipeTest : public testing::Test { +class PipePairTest : public testing::Test { protected: lldb_private::Pipe input; lldb_private::Pipe output; From a565926b9a2be2143941aebcc90a106b914a0557 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Tue, 24 Jun 2025 13:31:59 -0700 Subject: [PATCH 09/15] [lldb] Add more tests MCP protocol types Add unit testing for the different message types. (cherry picked from commit 7e3af676312ba9716f05600e47a6f5897307c4ff) --- lldb/unittests/Protocol/ProtocolMCPTest.cpp | 97 +++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index 00959f3ce20be..14cc240dd3628 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -133,3 +133,100 @@ TEST(ProtocolMCPTest, ToolDefinition) { EXPECT_EQ(tool_definition.inputSchema, deserialized_tool_definition->inputSchema); } + +TEST(ProtocolMCPTest, MessageWithRequest) { + Request request; + request.id = 1; + request.method = "test_method"; + request.params = llvm::json::Object{{"param", "value"}}; + + Message message = request; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Request &deserialized_request = + std::get(*deserialized_message); + + EXPECT_EQ(request.id, deserialized_request.id); + EXPECT_EQ(request.method, deserialized_request.method); + EXPECT_EQ(request.params, deserialized_request.params); +} + +TEST(ProtocolMCPTest, MessageWithResponse) { + Response response; + response.id = 2; + response.result = llvm::json::Object{{"result", "success"}}; + + Message message = response; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Response &deserialized_response = + std::get(*deserialized_message); + + EXPECT_EQ(response.id, deserialized_response.id); + EXPECT_EQ(response.result, deserialized_response.result); +} + +TEST(ProtocolMCPTest, MessageWithNotification) { + Notification notification; + notification.method = "notification_method"; + notification.params = llvm::json::Object{{"notify", "data"}}; + + Message message = notification; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Notification &deserialized_notification = + std::get(*deserialized_message); + + EXPECT_EQ(notification.method, deserialized_notification.method); + EXPECT_EQ(notification.params, deserialized_notification.params); +} + +TEST(ProtocolMCPTest, MessageWithError) { + ErrorInfo error_info; + error_info.code = -32603; + error_info.message = "Internal error"; + + Error error; + error.id = 3; + error.error = error_info; + + Message message = error; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Error &deserialized_error = std::get(*deserialized_message); + + EXPECT_EQ(error.id, deserialized_error.id); + EXPECT_EQ(error.error.code, deserialized_error.error.code); + EXPECT_EQ(error.error.message, deserialized_error.error.message); +} + +TEST(ProtocolMCPTest, ResponseWithError) { + ErrorInfo error_info; + error_info.code = -32700; + error_info.message = "Parse error"; + + Response response; + response.id = 4; + response.error = error_info; + + llvm::Expected deserialized_response = roundtripJSON(response); + ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + + EXPECT_EQ(response.id, deserialized_response->id); + EXPECT_FALSE(deserialized_response->result.has_value()); + ASSERT_TRUE(deserialized_response->error.has_value()); + EXPECT_EQ(response.error->code, deserialized_response->error->code); + EXPECT_EQ(response.error->message, deserialized_response->error->message); +} From 9140740283225ebf8f36d68e1267a6a3784f38c7 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Wed, 25 Jun 2025 15:46:33 -0500 Subject: [PATCH 10/15] [lldb] Make MCP server instance global (#145616) Rather than having one MCP server per debugger, make the MCP server global and pass a debugger id along with tool invocations that require one. This PR also adds a second tool to list the available debuggers with their targets so the model can decide which debugger instance to use. (cherry picked from commit e8abdfc88ffed632750fe0fd7deafb577e902bd6) --- lldb/include/lldb/Core/Debugger.h | 6 - lldb/include/lldb/Core/ProtocolServer.h | 5 +- lldb/include/lldb/lldb-forward.h | 2 +- lldb/include/lldb/lldb-private-interfaces.h | 3 +- .../Commands/CommandObjectProtocolServer.cpp | 51 ++------ lldb/source/Core/Debugger.cpp | 23 ---- lldb/source/Core/ProtocolServer.cpp | 34 +++++- lldb/source/Plugins/Protocol/MCP/Protocol.h | 2 + .../Protocol/MCP/ProtocolServerMCP.cpp | 30 ++--- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 6 +- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 109 ++++++++++++++---- lldb/source/Plugins/Protocol/MCP/Tool.h | 24 ++-- .../Protocol/ProtocolMCPServerTest.cpp | 21 ++-- 13 files changed, 180 insertions(+), 136 deletions(-) diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index 625ecf2ed26fc..35a41e419c9bf 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -617,10 +617,6 @@ class Debugger : public std::enable_shared_from_this, void FlushProcessOutput(Process &process, bool flush_stdout, bool flush_stderr); - void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp); - void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp); - lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const; - SourceManager::SourceFileCache &GetSourceFileCache() { return m_source_file_cache; } @@ -793,8 +789,6 @@ class Debugger : public std::enable_shared_from_this, mutable std::mutex m_progress_reports_mutex; /// @} - llvm::SmallVector m_protocol_servers; - std::mutex m_destroy_callback_mutex; lldb::callback_token_t m_destroy_callback_next_token = 0; struct DestroyCallbackInfo { diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h index fafe460904323..937256c10aec1 100644 --- a/lldb/include/lldb/Core/ProtocolServer.h +++ b/lldb/include/lldb/Core/ProtocolServer.h @@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface { ProtocolServer() = default; virtual ~ProtocolServer() = default; - static lldb::ProtocolServerSP Create(llvm::StringRef name, - Debugger &debugger); + static ProtocolServer *GetOrCreate(llvm::StringRef name); + + static std::vector GetSupportedProtocols(); struct Connection { Socket::SocketProtocol protocol; diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h index cdcd95443cc7a..d0e7d5e8e2120 100644 --- a/lldb/include/lldb/lldb-forward.h +++ b/lldb/include/lldb/lldb-forward.h @@ -389,7 +389,7 @@ typedef std::shared_ptr PlatformSP; typedef std::shared_ptr ProcessSP; typedef std::shared_ptr ProcessAttachInfoSP; typedef std::shared_ptr ProcessLaunchInfoSP; -typedef std::shared_ptr ProtocolServerSP; +typedef std::unique_ptr ProtocolServerUP; typedef std::weak_ptr ProcessWP; typedef std::shared_ptr RegisterCheckpointSP; typedef std::shared_ptr RegisterContextSP; diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h index 19ab5f435659b..6511269f32f37 100644 --- a/lldb/include/lldb/lldb-private-interfaces.h +++ b/lldb/include/lldb/lldb-private-interfaces.h @@ -82,8 +82,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force, typedef lldb::ProcessSP (*ProcessCreateInstance)( lldb::TargetSP target_sp, lldb::ListenerSP listener_sp, const FileSpec *crash_file_path, bool can_connect); -typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)( - Debugger &debugger); +typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)(); typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)( Target &target); typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)( diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp index 420fc5fdddadb..38d93cabf8c04 100644 --- a/lldb/source/Commands/CommandObjectProtocolServer.cpp +++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp @@ -24,20 +24,6 @@ using namespace lldb_private; #define LLDB_OPTIONS_mcp #include "CommandOptions.inc" -static std::vector GetSupportedProtocols() { - std::vector supported_protocols; - size_t i = 0; - - for (llvm::StringRef protocol_name = - PluginManager::GetProtocolServerPluginNameAtIndex(i++); - !protocol_name.empty(); - protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { - supported_protocols.push_back(protocol_name); - } - - return supported_protocols; -} - class CommandObjectProtocolServerStart : public CommandObjectParsed { public: CommandObjectProtocolServerStart(CommandInterpreter &interpreter) @@ -58,12 +44,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { } llvm::StringRef protocol = args.GetArgumentAtIndex(0); - std::vector supported_protocols = GetSupportedProtocols(); - if (llvm::find(supported_protocols, protocol) == - supported_protocols.end()) { + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { result.AppendErrorWithFormatv( "unsupported protocol: {0}. Supported protocols are: {1}", protocol, - llvm::join(GetSupportedProtocols(), ", ")); + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); return; } @@ -73,10 +58,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { } llvm::StringRef connection_uri = args.GetArgumentAtIndex(1); - ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol); - if (!server_sp) - server_sp = ProtocolServer::Create(protocol, GetDebugger()); - const char *connection_error = "unsupported connection specifier, expected 'accept:///path' or " "'listen://[host]:port', got '{0}'."; @@ -99,14 +80,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname, uri->port.value_or(0)); - if (llvm::Error error = server_sp->Start(connection)) { + if (llvm::Error error = server->Start(connection)) { result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); return; } - GetDebugger().AddProtocolServer(server_sp); - - if (Socket *socket = server_sp->GetSocket()) { + if (Socket *socket = server->GetSocket()) { std::string address = llvm::join(socket->GetListeningConnectionURI(), ", "); result.AppendMessageWithFormatv( @@ -135,30 +114,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed { } llvm::StringRef protocol = args.GetArgumentAtIndex(0); - std::vector supported_protocols = GetSupportedProtocols(); - if (llvm::find(supported_protocols, protocol) == - supported_protocols.end()) { + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { result.AppendErrorWithFormatv( "unsupported protocol: {0}. Supported protocols are: {1}", protocol, - llvm::join(GetSupportedProtocols(), ", ")); + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); return; } - Debugger &debugger = GetDebugger(); - - ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol); - if (!server_sp) { - result.AppendError( - llvm::formatv("no {0} protocol server running", protocol).str()); - return; - } - - if (llvm::Error error = server_sp->Stop()) { + if (llvm::Error error = server->Stop()) { result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); return; } - - debugger.RemoveProtocolServer(server_sp); } }; diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp index d8930ccf06d3b..bcafdb083ef3e 100644 --- a/lldb/source/Core/Debugger.cpp +++ b/lldb/source/Core/Debugger.cpp @@ -2380,26 +2380,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() { "Debugger::GetThreadPool called before Debugger::Initialize"); return *g_thread_pool; } - -void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { - assert(protocol_server_sp && - GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr); - m_protocol_servers.push_back(protocol_server_sp); -} - -void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { - auto it = llvm::find(m_protocol_servers, protocol_server_sp); - if (it != m_protocol_servers.end()) - m_protocol_servers.erase(it); -} - -lldb::ProtocolServerSP -Debugger::GetProtocolServer(llvm::StringRef protocol) const { - for (ProtocolServerSP protocol_server_sp : m_protocol_servers) { - if (!protocol_server_sp) - continue; - if (protocol_server_sp->GetPluginName() == protocol) - return protocol_server_sp; - } - return nullptr; -} diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp index d57a047afa7b2..41636cdacdecc 100644 --- a/lldb/source/Core/ProtocolServer.cpp +++ b/lldb/source/Core/ProtocolServer.cpp @@ -12,10 +12,36 @@ using namespace lldb_private; using namespace lldb; -ProtocolServerSP ProtocolServer::Create(llvm::StringRef name, - Debugger &debugger) { +ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) { + static std::mutex g_mutex; + static llvm::StringMap g_protocol_server_instances; + + std::lock_guard guard(g_mutex); + + auto it = g_protocol_server_instances.find(name); + if (it != g_protocol_server_instances.end()) + return it->second.get(); + if (ProtocolServerCreateInstance create_callback = - PluginManager::GetProtocolCreateCallbackForPluginName(name)) - return create_callback(debugger); + PluginManager::GetProtocolCreateCallbackForPluginName(name)) { + auto pair = + g_protocol_server_instances.try_emplace(name, create_callback()); + return pair.first->second.get(); + } + return nullptr; } + +std::vector ProtocolServer::GetSupportedProtocols() { + std::vector supported_protocols; + size_t i = 0; + + for (llvm::StringRef protocol_name = + PluginManager::GetProtocolServerPluginNameAtIndex(i++); + !protocol_name.empty(); + protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { + supported_protocols.push_back(protocol_name); + } + + return supported_protocols; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h index e315899406573..cb790dc4e5596 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.h +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -123,6 +123,8 @@ using Message = std::variant; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +using ToolArguments = std::variant; + } // namespace lldb_private::mcp::protocol #endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 51caa9a0bd599..2797cedd566e3 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP) static constexpr size_t kChunkSize = 1024; -ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) - : ProtocolServer(), m_debugger(debugger) { +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { AddRequestHandler("initialize", std::bind(&ProtocolServerMCP::InitializeHandler, this, std::placeholders::_1)); @@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) "notifications/initialized", [](const protocol::Notification &) { LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); }); - AddTool(std::make_unique( - "lldb_command", "Run an lldb command.", m_debugger)); + AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + AddTool(std::make_unique( + "lldb_debugger_list", "List debugger instances with their debugger_id.")); } ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() { PluginManager::UnregisterPlugin(CreateInstance); } -lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { - return std::make_shared(debugger); +lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() { + return std::make_unique(); } llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { @@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { std::lock_guard guard(m_server_mutex); if (m_running) - return llvm::createStringError("server already running"); + return llvm::createStringError("the MCP server is already running"); Status status; m_listener = Socket::Create(connection.protocol, false, status); @@ -164,10 +165,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { if (llvm::Error error = handles.takeError()) return error; + m_running = true; m_listen_handlers = std::move(*handles); m_loop_thread = std::thread([=] { - llvm::set_thread_name( - llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID())); + llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); }); @@ -177,6 +178,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::Error ProtocolServerMCP::Stop() { { std::lock_guard guard(m_server_mutex); + if (!m_running) + return createStringError("the MCP sever is not running"); m_running = false; } @@ -313,11 +316,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { if (it == m_tools.end()) return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); - const json::Value *args = param_obj->get("arguments"); - if (!args) - return llvm::createStringError("no tool arguments"); + protocol::ToolArguments tool_args; + if (const json::Value *args = param_obj->get("arguments")) + tool_args = *args; - llvm::Expected text_result = it->second->Call(*args); + llvm::Expected text_result = + it->second->Call(tool_args); if (!text_result) return text_result.takeError(); diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 52bb92a04a802..d55882cc8ab09 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -21,7 +21,7 @@ namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { public: - ProtocolServerMCP(Debugger &debugger); + ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; virtual llvm::Error Start(ProtocolServer::Connection connection) override; @@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer { static llvm::StringRef GetPluginNameStatic() { return "MCP"; } static llvm::StringRef GetPluginDescriptionStatic(); - static lldb::ProtocolServerSP CreateInstance(Debugger &debugger); + static lldb::ProtocolServerUP CreateInstance(); llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } @@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer { llvm::StringLiteral kName = "lldb-mcp"; llvm::StringLiteral kVersion = "0.1.0"; - Debugger &m_debugger; - bool m_running = false; MainLoop m_loop; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index de8fcc8f3cb4c..5c4626cf66b32 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -7,22 +7,38 @@ //===----------------------------------------------------------------------===// #include "Tool.h" +#include "lldb/Core/Module.h" #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" using namespace lldb_private::mcp; using namespace llvm; -struct LLDBCommandToolArguments { +namespace { +struct CommandToolArguments { + uint64_t debugger_id; std::string arguments; }; -bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A, +bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - return O && O.map("arguments", A.arguments); + return O && O.map("debugger_id", A.debugger_id) && + O.mapOptional("arguments", A.arguments); } +/// Helper function to create a TextResult from a string output. +static lldb_private::mcp::protocol::TextResult +createTextResult(std::string output, bool is_error = false) { + lldb_private::mcp::protocol::TextResult text_result; + text_result.content.emplace_back( + lldb_private::mcp::protocol::TextContent{{std::move(output)}}); + text_result.isError = is_error; + return text_result; +} + +} // namespace + Tool::Tool(std::string name, std::string description) : m_name(std::move(name)), m_description(std::move(description)) {} @@ -37,22 +53,27 @@ protocol::ToolDefinition Tool::GetDefinition() const { return definition; } -LLDBCommandTool::LLDBCommandTool(std::string name, std::string description, - Debugger &debugger) - : Tool(std::move(name), std::move(description)), m_debugger(debugger) {} - llvm::Expected -LLDBCommandTool::Call(const llvm::json::Value &args) { - llvm::json::Path::Root root; +CommandTool::Call(const protocol::ToolArguments &args) { + if (!std::holds_alternative(args)) + return createStringError("CommandTool requires arguments"); + + json::Path::Root root; - LLDBCommandToolArguments arguments; - if (!fromJSON(args, arguments, root)) + CommandToolArguments arguments; + if (!fromJSON(std::get(args), arguments, root)) return root.getError(); + lldb::DebuggerSP debugger_sp = + Debugger::GetDebuggerAtIndex(arguments.debugger_id); + if (!debugger_sp) + return createStringError( + llvm::formatv("no debugger with id {0}", arguments.debugger_id)); + // FIXME: Disallow certain commands and their aliases. CommandReturnObject result(/*colors=*/false); - m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(), - eLazyBoolYes, result); + debugger_sp->GetCommandInterpreter().HandleCommand( + arguments.arguments.c_str(), eLazyBoolYes, result); std::string output; llvm::StringRef output_str = result.GetOutputString(); @@ -66,16 +87,64 @@ LLDBCommandTool::Call(const llvm::json::Value &args) { output += err_str; } - mcp::protocol::TextResult text_result; - text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); - text_result.isError = !result.Succeeded(); - return text_result; + return createTextResult(output, !result.Succeeded()); } -std::optional LLDBCommandTool::GetSchema() const { +std::optional CommandTool::GetSchema() const { + llvm::json::Object id_type{{"type", "number"}}; llvm::json::Object str_type{{"type", "string"}}; - llvm::json::Object properties{{"arguments", std::move(str_type)}}; + llvm::json::Object properties{{"debugger_id", std::move(id_type)}, + {"arguments", std::move(str_type)}}; + llvm::json::Array required{"debugger_id"}; llvm::json::Object schema{{"type", "object"}, - {"properties", std::move(properties)}}; + {"properties", std::move(properties)}, + {"required", std::move(required)}}; return schema; } + +llvm::Expected +DebuggerListTool::Call(const protocol::ToolArguments &args) { + if (!std::holds_alternative(args)) + return createStringError("DebuggerListTool takes no arguments"); + + llvm::json::Path::Root root; + + // Return a nested Markdown list with debuggers and target. + // Example output: + // + // - debugger 0 + // - target 0 /path/to/foo + // - target 1 + // - debugger 1 + // - target 0 /path/to/bar + // + // FIXME: Use Structured Content when we adopt protocol version 2025-06-18. + std::string output; + llvm::raw_string_ostream os(output); + + const size_t num_debuggers = Debugger::GetNumDebuggers(); + for (size_t i = 0; i < num_debuggers; ++i) { + lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i); + if (!debugger_sp) + continue; + + os << "- debugger " << i << '\n'; + + TargetList &target_list = debugger_sp->GetTargetList(); + const size_t num_targets = target_list.GetNumTargets(); + for (size_t j = 0; j < num_targets; ++j) { + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j); + if (!target_sp) + continue; + os << " - target " << j; + if (target_sp == target_list.GetSelectedTarget()) + os << " (selected)"; + // Append the module path if we have one. + if (Module *exe_module = target_sp->GetExecutableModulePointer()) + os << " " << exe_module->GetFileSpec().GetPath(); + os << '\n'; + } + } + + return createTextResult(output); +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index 57a5125813b76..74ab04b472522 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -22,10 +22,10 @@ class Tool { virtual ~Tool() = default; virtual llvm::Expected - Call(const llvm::json::Value &args) = 0; + Call(const protocol::ToolArguments &args) = 0; virtual std::optional GetSchema() const { - return std::nullopt; + return llvm::json::Object{{"type", "object"}}; } protocol::ToolDefinition GetDefinition() const; @@ -37,20 +37,26 @@ class Tool { std::string m_description; }; -class LLDBCommandTool : public mcp::Tool { +class CommandTool : public mcp::Tool { public: - LLDBCommandTool(std::string name, std::string description, - Debugger &debugger); - ~LLDBCommandTool() = default; + using mcp::Tool::Tool; + ~CommandTool() = default; virtual llvm::Expected - Call(const llvm::json::Value &args) override; + Call(const protocol::ToolArguments &args) override; virtual std::optional GetSchema() const override; +}; -private: - Debugger &m_debugger; +class DebuggerListTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + ~DebuggerListTool() = default; + + virtual llvm::Expected + Call(const protocol::ToolArguments &args) override; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index ae1f71c0ffee9..3aa144ed4e19f 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -46,9 +46,10 @@ class TestTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected - Call(const llvm::json::Value &args) override { + Call(const ToolArguments &args) override { std::string argument; - if (const json::Object *args_obj = args.getAsObject()) { + if (const json::Object *args_obj = + std::get(args).getAsObject()) { if (const json::Value *s = args_obj->get("arguments")) { argument = s->getAsString().value_or(""); } @@ -66,7 +67,7 @@ class ErrorTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected - Call(const llvm::json::Value &args) override { + Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; @@ -77,7 +78,7 @@ class FailTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected - Call(const llvm::json::Value &args) override { + Call(const ToolArguments &args) override { mcp::protocol::TextResult text_result; text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); text_result.isError = true; @@ -115,7 +116,7 @@ class ProtocolServerMCPTest : public ::testing::Test { ProtocolServer::Connection connection; connection.protocol = Socket::SocketProtocol::ProtocolTcp; connection.name = llvm::formatv("{0}:0", k_localhost).str(); - m_server_up = std::make_unique(*m_debugger_sp); + m_server_up = std::make_unique(); m_server_up->AddTool(std::make_unique("test", "test tool")); ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); @@ -145,7 +146,7 @@ class ProtocolServerMCPTest : public ::testing::Test { TEST_F(ProtocolServerMCPTest, Intialization) { llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"claude-ai","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; llvm::StringLiteral response = R"json({"jsonrpc":"2.0","id":0,"result":{"capabilities":{"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; @@ -167,7 +168,7 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"}},"type":"object"},"name":"lldb_command"}]}})json"; + R"json( {"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"List debugger instances with their debugger_id.","inputSchema":{"type":"object"},"name":"lldb_debugger_list"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -205,7 +206,7 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { TEST_F(ProtocolServerMCPTest, ToolsCall) { llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; @@ -227,7 +228,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { m_server_up->AddTool(std::make_unique("error", "error tool")); llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"error":{"code":-1,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; @@ -249,7 +250,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { m_server_up->AddTool(std::make_unique("fail", "fail tool")); llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; From c037d2b3e6bce7a03e84d759af18e723da4bc059 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Fri, 11 Jul 2025 15:49:27 -0700 Subject: [PATCH 11/15] [lldb] Expose debuggers and target as resources through MCP (#148075) Expose debuggers and target as resources through MCP. This has two advantages: 1. Enables returning data in a structured way. Although tools can return structured data with the latest revision of the protocol, we might not be able to update before the majority of clients has adopted it. 2. Enables the user to specify a resource themselves, rather than letting the model guess which debugger instance it should use. This PR exposes a resource for debuggers and targets. The following URI returns information about a given debugger instance: ``` lldb://debugger/ ``` For example: ``` { uri: "lldb://debugger/0" mimeType: "application/json" text: "{"debugger_id":0,"num_targets":2}" } ``` The following URI returns information about a given target: ``` lldb://debugger//target/ ``` For example: ``` { uri: "lldb://debugger/0/target/0" mimeType: "application/json" text: "{"arch":"arm64-apple-macosx26.0.0","debugger_id":0,"path":"/Users/jonas/llvm/build-ra/bin/count","target_id":0}" } ``` (cherry picked from commit 3c4c2fada26f479be7c2f9744f5b7364f7612446) --- lldb/include/lldb/Core/Debugger.h | 2 +- lldb/include/lldb/Target/Target.h | 2 +- .../Plugins/Protocol/MCP/CMakeLists.txt | 1 + lldb/source/Plugins/Protocol/MCP/MCPError.cpp | 11 + lldb/source/Plugins/Protocol/MCP/MCPError.h | 19 +- lldb/source/Plugins/Protocol/MCP/Protocol.cpp | 54 ++++- lldb/source/Plugins/Protocol/MCP/Protocol.h | 60 ++++- .../Protocol/MCP/ProtocolServerMCP.cpp | 89 ++++++- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 10 + lldb/source/Plugins/Protocol/MCP/Resource.cpp | 217 ++++++++++++++++++ lldb/source/Plugins/Protocol/MCP/Resource.h | 51 ++++ lldb/source/Plugins/Protocol/MCP/Tool.cpp | 49 +--- lldb/source/Plugins/Protocol/MCP/Tool.h | 9 - .../Protocol/ProtocolMCPServerTest.cpp | 43 +++- lldb/unittests/Protocol/ProtocolMCPTest.cpp | 98 ++++++++ 15 files changed, 645 insertions(+), 70 deletions(-) create mode 100644 lldb/source/Plugins/Protocol/MCP/Resource.cpp create mode 100644 lldb/source/Plugins/Protocol/MCP/Resource.h diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index 35a41e419c9bf..eb8aa314cf47b 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -376,7 +376,7 @@ class Debugger : public std::enable_shared_from_this, bool GetNotifyVoid() const; - const std::string &GetInstanceName() { return m_instance_name; } + const std::string &GetInstanceName() const { return m_instance_name; } bool GetShowInlineDiagnostics() const; diff --git a/lldb/include/lldb/Target/Target.h b/lldb/include/lldb/Target/Target.h index 50ebcc5a77946..79df47fec620e 100644 --- a/lldb/include/lldb/Target/Target.h +++ b/lldb/include/lldb/Target/Target.h @@ -1176,7 +1176,7 @@ class Target : public std::enable_shared_from_this, Architecture *GetArchitecturePlugin() const { return m_arch.GetPlugin(); } - Debugger &GetDebugger() { return m_debugger; } + Debugger &GetDebugger() const { return m_debugger; } size_t ReadMemoryFromFileCache(const Address &addr, void *dst, size_t dst_len, Status &error); diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt index db31a7a69cb33..e104fb527e57a 100644 --- a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt @@ -2,6 +2,7 @@ add_lldb_library(lldbPluginProtocolServerMCP PLUGIN MCPError.cpp Protocol.cpp ProtocolServerMCP.cpp + Resource.cpp Tool.cpp LINK_COMPONENTS diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp index 5ed850066b659..659b53a14fe23 100644 --- a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp @@ -14,6 +14,7 @@ namespace lldb_private::mcp { char MCPError::ID; +char UnsupportedURI::ID; MCPError::MCPError(std::string message, int64_t error_code) : m_message(message), m_error_code(error_code) {} @@ -31,4 +32,14 @@ protocol::Error MCPError::toProtcolError() const { return error; } +UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} + +void UnsupportedURI::log(llvm::raw_ostream &OS) const { + OS << "unsupported uri: " << m_uri; +} + +std::error_code UnsupportedURI::convertToErrorCode() const { + return llvm::inconvertibleErrorCode(); +} + } // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/source/Plugins/Protocol/MCP/MCPError.h index 2a76a7b087e20..f4db13d6deade 100644 --- a/lldb/source/Plugins/Protocol/MCP/MCPError.h +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.h @@ -8,6 +8,7 @@ #include "Protocol.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include namespace lldb_private::mcp { @@ -16,7 +17,7 @@ class MCPError : public llvm::ErrorInfo { public: static char ID; - MCPError(std::string message, int64_t error_code); + MCPError(std::string message, int64_t error_code = kInternalError); void log(llvm::raw_ostream &OS) const override; std::error_code convertToErrorCode() const override; @@ -25,9 +26,25 @@ class MCPError : public llvm::ErrorInfo { protocol::Error toProtcolError() const; + static constexpr int64_t kResourceNotFound = -32002; + static constexpr int64_t kInternalError = -32603; + private: std::string m_message; int64_t m_error_code; }; +class UnsupportedURI : public llvm::ErrorInfo { +public: + static char ID; + + UnsupportedURI(std::string uri); + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_uri; +}; + } // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp index d66c931a0b284..e42e1bf1118cf 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp @@ -107,8 +107,36 @@ bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, return O && O.map("listChanged", TC.listChanged); } +llvm::json::Value toJSON(const ResourceCapability &RC) { + return llvm::json::Object{{"listChanged", RC.listChanged}, + {"subscribe", RC.subscribe}}; +} + +bool fromJSON(const llvm::json::Value &V, ResourceCapability &RC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("listChanged", RC.listChanged) && + O.map("subscribe", RC.subscribe); +} + llvm::json::Value toJSON(const Capabilities &C) { - return llvm::json::Object{{"tools", C.tools}}; + return llvm::json::Object{{"tools", C.tools}, {"resources", C.resources}}; +} + +bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("uri", R.uri) && O.map("name", R.name) && + O.mapOptional("description", R.description) && + O.mapOptional("mimeType", R.mimeType); +} + +llvm::json::Value toJSON(const Resource &R) { + llvm::json::Object Result{{"uri", R.uri}, {"name", R.name}}; + if (R.description) + Result.insert({"description", R.description}); + if (R.mimeType) + Result.insert({"mimeType", R.mimeType}); + return Result; } bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { @@ -116,6 +144,30 @@ bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { return O && O.map("tools", C.tools); } +llvm::json::Value toJSON(const ResourceContents &RC) { + llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; + if (RC.mimeType) + Result.insert({"mimeType", RC.mimeType}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("uri", RC.uri) && O.map("text", RC.text) && + O.mapOptional("mimeType", RC.mimeType); +} + +llvm::json::Value toJSON(const ResourceResult &RR) { + return llvm::json::Object{{"contents", RR.contents}}; +} + +bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("contents", RR.contents); +} + llvm::json::Value toJSON(const TextContent &TC) { return llvm::json::Object{{"type", "text"}, {"text", TC.text}}; } diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h index cb790dc4e5596..ffe621bee1c2a 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.h +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -76,17 +76,75 @@ struct ToolCapability { llvm::json::Value toJSON(const ToolCapability &); bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); +struct ResourceCapability { + /// Whether this server supports notifications for changes to the resources + /// list. + bool listChanged = false; + + /// Whether subscriptions are supported. + bool subscribe = false; +}; + +llvm::json::Value toJSON(const ResourceCapability &); +bool fromJSON(const llvm::json::Value &, ResourceCapability &, + llvm::json::Path); + /// Capabilities that a server may support. Known capabilities are defined here, /// in this schema, but this is not a closed set: any server can define its own, /// additional capabilities. struct Capabilities { - /// Present if the server offers any tools to call. + /// Tool capabilities of the server. ToolCapability tools; + + /// Resource capabilities of the server. + ResourceCapability resources; }; llvm::json::Value toJSON(const Capabilities &); bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); +/// A known resource that the server is capable of reading. +struct Resource { + /// The URI of this resource. + std::string uri; + + /// A human-readable name for this resource. + std::string name; + + /// A description of what this resource represents. + std::optional description; + + /// The MIME type of this resource, if known. + std::optional mimeType; +}; + +llvm::json::Value toJSON(const Resource &); +bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path); + +/// The contents of a specific resource or sub-resource. +struct ResourceContents { + /// The URI of this resource. + std::string uri; + + /// The text of the item. This must only be set if the item can actually be + /// represented as text (not binary data). + std::string text; + + /// The MIME type of this resource, if known. + std::optional mimeType; +}; + +llvm::json::Value toJSON(const ResourceContents &); +bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path); + +/// The server's response to a resources/read request from the client. +struct ResourceResult { + std::vector contents; +}; + +llvm::json::Value toJSON(const ResourceResult &); +bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path); + /// Text provided to or from an LLM. struct TextContent { /// The text content of the message. diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 2797cedd566e3..0d79dcdad2d65 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -28,20 +28,29 @@ ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { AddRequestHandler("initialize", std::bind(&ProtocolServerMCP::InitializeHandler, this, std::placeholders::_1)); + AddRequestHandler("tools/list", std::bind(&ProtocolServerMCP::ToolsListHandler, this, std::placeholders::_1)); AddRequestHandler("tools/call", std::bind(&ProtocolServerMCP::ToolsCallHandler, this, std::placeholders::_1)); + + AddRequestHandler("resources/list", + std::bind(&ProtocolServerMCP::ResourcesListHandler, this, + std::placeholders::_1)); + AddRequestHandler("resources/read", + std::bind(&ProtocolServerMCP::ResourcesReadHandler, this, + std::placeholders::_1)); AddNotificationHandler( "notifications/initialized", [](const protocol::Notification &) { LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); }); + AddTool( std::make_unique("lldb_command", "Run an lldb command.")); - AddTool(std::make_unique( - "lldb_debugger_list", "List debugger instances with their debugger_id.")); + + AddResourceProvider(std::make_unique()); } ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -75,7 +84,7 @@ ProtocolServerMCP::Handle(protocol::Request request) { } return make_error( - llvm::formatv("no handler for request: {0}", request.method).str(), 1); + llvm::formatv("no handler for request: {0}", request.method).str()); } void ProtocolServerMCP::Handle(protocol::Notification notification) { @@ -218,7 +227,7 @@ ProtocolServerMCP::HandleData(llvm::StringRef data) { response.takeError(), [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, [&](const llvm::ErrorInfoBase &err) { - protocol_error.error.code = -1; + protocol_error.error.code = MCPError::kInternalError; protocol_error.error.message = err.message(); }); protocol_error.id = request->id; @@ -246,6 +255,9 @@ ProtocolServerMCP::HandleData(llvm::StringRef data) { protocol::Capabilities ProtocolServerMCP::GetCapabilities() { protocol::Capabilities capabilities; capabilities.tools.listChanged = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.resources.listChanged = false; return capabilities; } @@ -257,6 +269,15 @@ void ProtocolServerMCP::AddTool(std::unique_ptr tool) { m_tools[tool->GetName()] = std::move(tool); } +void ProtocolServerMCP::AddResourceProvider( + std::unique_ptr resource_provider) { + std::lock_guard guard(m_server_mutex); + + if (!resource_provider) + return; + m_resource_providers.push_back(std::move(resource_provider)); +} + void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { std::lock_guard guard(m_server_mutex); @@ -329,3 +350,63 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { return response; } + +llvm::Expected +ProtocolServerMCP::ResourcesListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array resources; + + std::lock_guard guard(m_server_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + for (const protocol::Resource &resource : + resource_provider_up->GetResources()) + resources.push_back(resource); + } + response.result.emplace( + llvm::json::Object{{"resources", std::move(resources)}}); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ResourcesReadHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return llvm::createStringError("no resource parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no resource parameters"); + + const json::Value *uri = param_obj->get("uri"); + if (!uri) + return llvm::createStringError("no resource uri"); + + llvm::StringRef uri_str = uri->getAsString().value_or(""); + if (uri_str.empty()) + return llvm::createStringError("no resource uri"); + + std::lock_guard guard(m_server_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + llvm::Expected result = + resource_provider_up->ReadResource(uri_str); + if (result.errorIsA()) { + llvm::consumeError(result.takeError()); + continue; + } + if (!result) + return result.takeError(); + + protocol::Response response; + response.result.emplace(std::move(*result)); + return response; + } + + return make_error( + llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + MCPError::kResourceNotFound); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index d55882cc8ab09..e273f6e2a8d37 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -10,6 +10,7 @@ #define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H #include "Protocol.h" +#include "Resource.h" #include "Tool.h" #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/MainLoop.h" @@ -46,6 +47,8 @@ class ProtocolServerMCP : public ProtocolServer { std::function; void AddTool(std::unique_ptr tool); + void AddResourceProvider(std::unique_ptr resource_provider); + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); void AddNotificationHandler(llvm::StringRef method, NotificationHandler handler); @@ -61,11 +64,17 @@ class ProtocolServerMCP : public ProtocolServer { llvm::Expected InitializeHandler(const protocol::Request &); + llvm::Expected ToolsListHandler(const protocol::Request &); llvm::Expected ToolsCallHandler(const protocol::Request &); + llvm::Expected + ResourcesListHandler(const protocol::Request &); + llvm::Expected + ResourcesReadHandler(const protocol::Request &); + protocol::Capabilities GetCapabilities(); llvm::StringLiteral kName = "lldb-mcp"; @@ -89,6 +98,7 @@ class ProtocolServerMCP : public ProtocolServer { std::mutex m_server_mutex; llvm::StringMap> m_tools; + std::vector> m_resource_providers; llvm::StringMap m_request_handlers; llvm::StringMap m_notification_handlers; diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp new file mode 100644 index 0000000000000..d75d5b6dd6a41 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -0,0 +1,217 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Resource.h" +#include "MCPError.h" +#include "lldb/Core/Debugger.h" +#include "lldb/Core/Module.h" +#include "lldb/Target/Platform.h" + +using namespace lldb_private::mcp; + +namespace { +struct DebuggerResource { + uint64_t debugger_id = 0; + std::string name; + uint64_t num_targets = 0; +}; + +llvm::json::Value toJSON(const DebuggerResource &DR) { + llvm::json::Object Result{{"debugger_id", DR.debugger_id}, + {"num_targets", DR.num_targets}}; + if (!DR.name.empty()) + Result.insert({"name", DR.name}); + return Result; +} + +struct TargetResource { + size_t debugger_id = 0; + size_t target_idx = 0; + bool selected = false; + bool dummy = false; + std::string arch; + std::string path; + std::string platform; +}; + +llvm::json::Value toJSON(const TargetResource &TR) { + llvm::json::Object Result{{"debugger_id", TR.debugger_id}, + {"target_idx", TR.target_idx}, + {"selected", TR.selected}, + {"dummy", TR.dummy}}; + if (!TR.arch.empty()) + Result.insert({"arch", TR.arch}); + if (!TR.path.empty()) + Result.insert({"path", TR.path}); + if (!TR.platform.empty()) + Result.insert({"platform", TR.platform}); + return Result; +} +} // namespace + +static constexpr llvm::StringLiteral kMimeTypeJSON = "application/json"; + +template +static llvm::Error createStringError(const char *format, Args &&...args) { + return llvm::createStringError( + llvm::formatv(format, std::forward(args)...).str()); +} + +static llvm::Error createUnsupportedURIError(llvm::StringRef uri) { + return llvm::make_error(uri.str()); +} + +protocol::Resource +DebuggerResourceProvider::GetDebuggerResource(Debugger &debugger) { + const lldb::user_id_t debugger_id = debugger.GetID(); + + protocol::Resource resource; + resource.uri = llvm::formatv("lldb://debugger/{0}", debugger_id); + resource.name = debugger.GetInstanceName(); + resource.description = + llvm::formatv("Information about debugger instance {0}: {1}", debugger_id, + debugger.GetInstanceName()); + resource.mimeType = kMimeTypeJSON; + return resource; +} + +protocol::Resource +DebuggerResourceProvider::GetTargetResource(size_t target_idx, Target &target) { + const size_t debugger_id = target.GetDebugger().GetID(); + + std::string target_name = llvm::formatv("target {0}", target_idx); + + if (Module *exe_module = target.GetExecutableModulePointer()) + target_name = exe_module->GetFileSpec().GetFilename().GetString(); + + protocol::Resource resource; + resource.uri = + llvm::formatv("lldb://debugger/{0}/target/{1}", debugger_id, target_idx); + resource.name = target_name; + resource.description = + llvm::formatv("Information about target {0} in debugger instance {1}", + target_idx, debugger_id); + resource.mimeType = kMimeTypeJSON; + return resource; +} + +std::vector DebuggerResourceProvider::GetResources() const { + std::vector resources; + + const size_t num_debuggers = Debugger::GetNumDebuggers(); + for (size_t i = 0; i < num_debuggers; ++i) { + lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i); + if (!debugger_sp) + continue; + resources.emplace_back(GetDebuggerResource(*debugger_sp)); + + TargetList &target_list = debugger_sp->GetTargetList(); + const size_t num_targets = target_list.GetNumTargets(); + for (size_t j = 0; j < num_targets; ++j) { + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j); + if (!target_sp) + continue; + resources.emplace_back(GetTargetResource(j, *target_sp)); + } + } + + return resources; +} + +llvm::Expected +DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { + + auto [protocol, path] = uri.split("://"); + + if (protocol != "lldb") + return createUnsupportedURIError(uri); + + llvm::SmallVector components; + path.split(components, '/'); + + if (components.size() < 2) + return createUnsupportedURIError(uri); + + if (components[0] != "debugger") + return createUnsupportedURIError(uri); + + size_t debugger_idx; + if (components[1].getAsInteger(0, debugger_idx)) + return createStringError("invalid debugger id '{0}': {1}", components[1], + path); + + if (components.size() > 3) { + if (components[2] != "target") + return createUnsupportedURIError(uri); + + size_t target_idx; + if (components[3].getAsInteger(0, target_idx)) + return createStringError("invalid target id '{0}': {1}", components[3], + path); + + return ReadTargetResource(uri, debugger_idx, target_idx); + } + + return ReadDebuggerResource(uri, debugger_idx); +} + +llvm::Expected +DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, + lldb::user_id_t debugger_id) { + lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); + if (!debugger_sp) + return createStringError("invalid debugger id: {0}", debugger_id); + + DebuggerResource debugger_resource; + debugger_resource.debugger_id = debugger_id; + debugger_resource.name = debugger_sp->GetInstanceName(); + debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); + + protocol::ResourceContents contents; + contents.uri = uri; + contents.mimeType = kMimeTypeJSON; + contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); + + protocol::ResourceResult result; + result.contents.push_back(contents); + return result; +} + +llvm::Expected +DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, + lldb::user_id_t debugger_id, + size_t target_idx) { + + lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); + if (!debugger_sp) + return createStringError("invalid debugger id: {0}", debugger_id); + + TargetList &target_list = debugger_sp->GetTargetList(); + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(target_idx); + if (!target_sp) + return createStringError("invalid target idx: {0}", target_idx); + + TargetResource target_resource; + target_resource.debugger_id = debugger_id; + target_resource.target_idx = target_idx; + target_resource.arch = target_sp->GetArchitecture().GetTriple().str(); + target_resource.dummy = target_sp->IsDummyTarget(); + target_resource.selected = target_sp == debugger_sp->GetSelectedTarget(); + + if (Module *exe_module = target_sp->GetExecutableModulePointer()) + target_resource.path = exe_module->GetFileSpec().GetPath(); + if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) + target_resource.platform = platform_sp->GetName(); + + protocol::ResourceContents contents; + contents.uri = uri; + contents.mimeType = kMimeTypeJSON; + contents.text = llvm::formatv("{0}", toJSON(target_resource)); + + protocol::ResourceResult result; + result.contents.push_back(contents); + return result; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h new file mode 100644 index 0000000000000..5ac38e7e878ff --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H +#define LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H + +#include "Protocol.h" +#include "lldb/lldb-private.h" +#include + +namespace lldb_private::mcp { + +class ResourceProvider { +public: + ResourceProvider() = default; + virtual ~ResourceProvider() = default; + + virtual std::vector GetResources() const = 0; + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const = 0; +}; + +class DebuggerResourceProvider : public ResourceProvider { +public: + using ResourceProvider::ResourceProvider; + virtual ~DebuggerResourceProvider() = default; + + virtual std::vector GetResources() const override; + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const override; + +private: + static protocol::Resource GetDebuggerResource(Debugger &debugger); + static protocol::Resource GetTargetResource(size_t target_idx, + Target &target); + + static llvm::Expected + ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); + static llvm::Expected + ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, + size_t target_idx); +}; + +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 5c4626cf66b32..eecd56141d2ed 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -65,7 +65,7 @@ CommandTool::Call(const protocol::ToolArguments &args) { return root.getError(); lldb::DebuggerSP debugger_sp = - Debugger::GetDebuggerAtIndex(arguments.debugger_id); + Debugger::FindDebuggerWithID(arguments.debugger_id); if (!debugger_sp) return createStringError( llvm::formatv("no debugger with id {0}", arguments.debugger_id)); @@ -101,50 +101,3 @@ std::optional CommandTool::GetSchema() const { {"required", std::move(required)}}; return schema; } - -llvm::Expected -DebuggerListTool::Call(const protocol::ToolArguments &args) { - if (!std::holds_alternative(args)) - return createStringError("DebuggerListTool takes no arguments"); - - llvm::json::Path::Root root; - - // Return a nested Markdown list with debuggers and target. - // Example output: - // - // - debugger 0 - // - target 0 /path/to/foo - // - target 1 - // - debugger 1 - // - target 0 /path/to/bar - // - // FIXME: Use Structured Content when we adopt protocol version 2025-06-18. - std::string output; - llvm::raw_string_ostream os(output); - - const size_t num_debuggers = Debugger::GetNumDebuggers(); - for (size_t i = 0; i < num_debuggers; ++i) { - lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i); - if (!debugger_sp) - continue; - - os << "- debugger " << i << '\n'; - - TargetList &target_list = debugger_sp->GetTargetList(); - const size_t num_targets = target_list.GetNumTargets(); - for (size_t j = 0; j < num_targets; ++j) { - lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j); - if (!target_sp) - continue; - os << " - target " << j; - if (target_sp == target_list.GetSelectedTarget()) - os << " (selected)"; - // Append the module path if we have one. - if (Module *exe_module = target_sp->GetExecutableModulePointer()) - os << " " << exe_module->GetFileSpec().GetPath(); - os << '\n'; - } - } - - return createTextResult(output); -} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index 74ab04b472522..d0f639adad24e 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -48,15 +48,6 @@ class CommandTool : public mcp::Tool { virtual std::optional GetSchema() const override; }; -class DebuggerListTool : public mcp::Tool { -public: - using mcp::Tool::Tool; - ~DebuggerListTool() = default; - - virtual llvm::Expected - Call(const protocol::ToolArguments &args) override; -}; - } // namespace lldb_private::mcp #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 3aa144ed4e19f..b2dcc740b5efd 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" +#include "Plugins/Protocol/MCP/MCPError.h" #include "Plugins/Protocol/MCP/ProtocolServerMCP.h" #include "TestingSupport/Host/SocketTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" @@ -28,6 +29,7 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { public: using ProtocolServerMCP::AddNotificationHandler; using ProtocolServerMCP::AddRequestHandler; + using ProtocolServerMCP::AddResourceProvider; using ProtocolServerMCP::AddTool; using ProtocolServerMCP::GetSocket; using ProtocolServerMCP::ProtocolServerMCP; @@ -61,6 +63,38 @@ class TestTool : public mcp::Tool { } }; +class TestResourceProvider : public mcp::ResourceProvider { + using mcp::ResourceProvider::ResourceProvider; + + virtual std::vector GetResources() const override { + std::vector resources; + + Resource resource; + resource.uri = "lldb://foo/bar"; + resource.name = "name"; + resource.description = "description"; + resource.mimeType = "application/json"; + + resources.push_back(resource); + return resources; + } + + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const override { + if (uri != "lldb://foo/bar") + return llvm::make_error(uri.str()); + + ResourceContents contents; + contents.uri = "lldb://foo/bar"; + contents.mimeType = "application/json"; + contents.text = "foobar"; + + ResourceResult result; + result.contents.push_back(contents); + return result; + } +}; + /// Test tool that returns an error. class ErrorTool : public mcp::Tool { public: @@ -118,6 +152,7 @@ class ProtocolServerMCPTest : public ::testing::Test { connection.name = llvm::formatv("{0}:0", k_localhost).str(); m_server_up = std::make_unique(); m_server_up->AddTool(std::make_unique("test", "test tool")); + m_server_up->AddResourceProvider(std::make_unique()); ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); // Connect to the server over a TCP socket. @@ -148,7 +183,7 @@ TEST_F(ProtocolServerMCPTest, Intialization) { llvm::StringLiteral request = R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; llvm::StringLiteral response = - R"json({"jsonrpc":"2.0","id":0,"result":{"capabilities":{"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -168,7 +203,7 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json( {"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"List debugger instances with their debugger_id.","inputSchema":{"type":"object"},"name":"lldb_debugger_list"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; + R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -188,7 +223,7 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { llvm::StringLiteral request = R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; llvm::StringLiteral response = - R"json({"error":{"code":1,"message":"no handler for request: resources/list"},"id":2,"jsonrpc":"2.0"})json"; + R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -230,7 +265,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = - R"json({"error":{"code":-1,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; + R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index 14cc240dd3628..ddc5a411a5c31 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -230,3 +230,101 @@ TEST(ProtocolMCPTest, ResponseWithError) { EXPECT_EQ(response.error->code, deserialized_response->error->code); EXPECT_EQ(response.error->message, deserialized_response->error->message); } + +TEST(ProtocolMCPTest, Resource) { + Resource resource; + resource.uri = "resource://example/test"; + resource.name = "Test Resource"; + resource.description = "A test resource for unit testing"; + resource.mimeType = "text/plain"; + + llvm::Expected deserialized_resource = roundtripJSON(resource); + ASSERT_THAT_EXPECTED(deserialized_resource, llvm::Succeeded()); + + EXPECT_EQ(resource.uri, deserialized_resource->uri); + EXPECT_EQ(resource.name, deserialized_resource->name); + EXPECT_EQ(resource.description, deserialized_resource->description); + EXPECT_EQ(resource.mimeType, deserialized_resource->mimeType); +} + +TEST(ProtocolMCPTest, ResourceWithoutOptionals) { + Resource resource; + resource.uri = "resource://example/minimal"; + resource.name = "Minimal Resource"; + + llvm::Expected deserialized_resource = roundtripJSON(resource); + ASSERT_THAT_EXPECTED(deserialized_resource, llvm::Succeeded()); + + EXPECT_EQ(resource.uri, deserialized_resource->uri); + EXPECT_EQ(resource.name, deserialized_resource->name); + EXPECT_FALSE(deserialized_resource->description.has_value()); + EXPECT_FALSE(deserialized_resource->mimeType.has_value()); +} + +TEST(ProtocolMCPTest, ResourceContents) { + ResourceContents contents; + contents.uri = "resource://example/content"; + contents.text = "This is the content of the resource"; + contents.mimeType = "text/plain"; + + llvm::Expected deserialized_contents = + roundtripJSON(contents); + ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); + + EXPECT_EQ(contents.uri, deserialized_contents->uri); + EXPECT_EQ(contents.text, deserialized_contents->text); + EXPECT_EQ(contents.mimeType, deserialized_contents->mimeType); +} + +TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { + ResourceContents contents; + contents.uri = "resource://example/content-no-mime"; + contents.text = "Content without mime type specified"; + + llvm::Expected deserialized_contents = + roundtripJSON(contents); + ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); + + EXPECT_EQ(contents.uri, deserialized_contents->uri); + EXPECT_EQ(contents.text, deserialized_contents->text); + EXPECT_FALSE(deserialized_contents->mimeType.has_value()); +} + +TEST(ProtocolMCPTest, ResourceResult) { + ResourceContents contents1; + contents1.uri = "resource://example/content1"; + contents1.text = "First resource content"; + contents1.mimeType = "text/plain"; + + ResourceContents contents2; + contents2.uri = "resource://example/content2"; + contents2.text = "Second resource content"; + contents2.mimeType = "application/json"; + + ResourceResult result; + result.contents = {contents1, contents2}; + + llvm::Expected deserialized_result = roundtripJSON(result); + ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); + + ASSERT_EQ(result.contents.size(), deserialized_result->contents.size()); + + EXPECT_EQ(result.contents[0].uri, deserialized_result->contents[0].uri); + EXPECT_EQ(result.contents[0].text, deserialized_result->contents[0].text); + EXPECT_EQ(result.contents[0].mimeType, + deserialized_result->contents[0].mimeType); + + EXPECT_EQ(result.contents[1].uri, deserialized_result->contents[1].uri); + EXPECT_EQ(result.contents[1].text, deserialized_result->contents[1].text); + EXPECT_EQ(result.contents[1].mimeType, + deserialized_result->contents[1].mimeType); +} + +TEST(ProtocolMCPTest, ResourceResultEmpty) { + ResourceResult result; + + llvm::Expected deserialized_result = roundtripJSON(result); + ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); + + EXPECT_TRUE(deserialized_result->contents.empty()); +} From 502d2a1ca08c3195458c7ed6ae21a3ab529b5e43 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Fri, 11 Jul 2025 17:22:22 -0700 Subject: [PATCH 12/15] [lldb] Simplify handling of empty strings for MCP (NFC) (#148317) Instead of storing a `std::optional`, directly use a `std::string` and treat a missing value the same was as an empty string. (cherry picked from commit 6fea3da40447514102118f2aeece590af0e16e5c) --- lldb/source/Plugins/Protocol/MCP/Protocol.cpp | 10 +++++----- lldb/source/Plugins/Protocol/MCP/Protocol.h | 10 +++++----- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 2 +- lldb/unittests/Protocol/ProtocolMCPTest.cpp | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp index e42e1bf1118cf..274ba6fac01ec 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp @@ -42,7 +42,7 @@ bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { llvm::json::Value toJSON(const ErrorInfo &EI) { llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; - if (EI.data) + if (!EI.data.empty()) Result.insert({"data", EI.data}); return Result; } @@ -132,9 +132,9 @@ bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { llvm::json::Value toJSON(const Resource &R) { llvm::json::Object Result{{"uri", R.uri}, {"name", R.name}}; - if (R.description) + if (!R.description.empty()) Result.insert({"description", R.description}); - if (R.mimeType) + if (!R.mimeType.empty()) Result.insert({"mimeType", R.mimeType}); return Result; } @@ -146,7 +146,7 @@ bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { llvm::json::Value toJSON(const ResourceContents &RC) { llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; - if (RC.mimeType) + if (!RC.mimeType.empty()) Result.insert({"mimeType", RC.mimeType}); return Result; } @@ -188,7 +188,7 @@ bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { llvm::json::Value toJSON(const ToolDefinition &TD) { llvm::json::Object Result{{"name", TD.name}}; - if (TD.description) + if (!TD.description.empty()) Result.insert({"description", TD.description}); if (TD.inputSchema) Result.insert({"inputSchema", TD.inputSchema}); diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h index ffe621bee1c2a..ce74836e62541 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.h +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -36,7 +36,7 @@ bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); struct ErrorInfo { int64_t code = 0; std::string message; - std::optional data; + std::string data; }; llvm::json::Value toJSON(const ErrorInfo &); @@ -112,10 +112,10 @@ struct Resource { std::string name; /// A description of what this resource represents. - std::optional description; + std::string description; /// The MIME type of this resource, if known. - std::optional mimeType; + std::string mimeType; }; llvm::json::Value toJSON(const Resource &); @@ -131,7 +131,7 @@ struct ResourceContents { std::string text; /// The MIME type of this resource, if known. - std::optional mimeType; + std::string mimeType; }; llvm::json::Value toJSON(const ResourceContents &); @@ -167,7 +167,7 @@ struct ToolDefinition { std::string name; /// Human-readable description. - std::optional description; + std::string description; // JSON Schema for the tool's parameters. std::optional inputSchema; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index eecd56141d2ed..bbc19a1e51942 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -45,7 +45,7 @@ Tool::Tool(std::string name, std::string description) protocol::ToolDefinition Tool::GetDefinition() const { protocol::ToolDefinition definition; definition.name = m_name; - definition.description.emplace(m_description); + definition.description = m_description; if (std::optional input_schema = GetSchema()) definition.inputSchema = *input_schema; diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index ddc5a411a5c31..ce8120cbfe9b9 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -257,8 +257,8 @@ TEST(ProtocolMCPTest, ResourceWithoutOptionals) { EXPECT_EQ(resource.uri, deserialized_resource->uri); EXPECT_EQ(resource.name, deserialized_resource->name); - EXPECT_FALSE(deserialized_resource->description.has_value()); - EXPECT_FALSE(deserialized_resource->mimeType.has_value()); + EXPECT_TRUE(deserialized_resource->description.empty()); + EXPECT_TRUE(deserialized_resource->mimeType.empty()); } TEST(ProtocolMCPTest, ResourceContents) { @@ -287,7 +287,7 @@ TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { EXPECT_EQ(contents.uri, deserialized_contents->uri); EXPECT_EQ(contents.text, deserialized_contents->text); - EXPECT_FALSE(deserialized_contents->mimeType.has_value()); + EXPECT_TRUE(deserialized_contents->mimeType.empty()); } TEST(ProtocolMCPTest, ResourceResult) { From c2edeb5caf5a4ebe7b64887fcfef412b0eb1f69b Mon Sep 17 00:00:00 2001 From: Pavel Labath Date: Thu, 27 Feb 2025 11:15:59 +0100 Subject: [PATCH 13/15] [lldb] Assorted improvements to the Pipe class (#128719) The main motivation for this was the inconsistency in handling of partial reads/writes between the windows and posix implementations (windows was returning partial reads, posix was trying to fill the buffer completely). I settle on the windows implementation, as that's the more common behavior, and the "eager" version can be implemented on top of that (in most cases, it isn't necessary, since we're writing just a single byte). Since this also required auditing the callers to make sure they're handling partial reads/writes correctly, I used the opportunity to modernize the function signatures as a forcing function. They now use the `Timeout` class (basically an `optional`) to support both polls (timeout=0) and blocking (timeout=nullopt) operations in a single function, and use an `Expected` instead of a by-ref result to return the number of bytes read/written. As a drive-by, I also fix a problem with the windows implementation where we were rounding the timeout value down, which meant that calls could time out slightly sooner than expected. (cherry picked from commit c0b5451129bba52e33cd7957d58af897a58d14c6) --- lldb/include/lldb/Host/PipeBase.h | 27 ++--- lldb/include/lldb/Host/posix/PipePosix.h | 19 +-- lldb/include/lldb/Host/windows/PipeWindows.h | 18 +-- lldb/source/Host/common/PipeBase.cpp | 16 --- lldb/source/Host/common/Socket.cpp | 31 +++-- .../posix/ConnectionFileDescriptorPosix.cpp | 16 +-- lldb/source/Host/posix/MainLoopPosix.cpp | 6 +- lldb/source/Host/posix/PipePosix.cpp | 107 ++++++++--------- lldb/source/Host/windows/PipeWindows.cpp | 87 ++++++-------- .../gdb-remote/GDBRemoteCommunication.cpp | 24 +++- lldb/source/Target/Process.cpp | 17 +-- lldb/tools/lldb-server/lldb-gdbserver.cpp | 28 +++-- lldb/unittests/Host/PipeTest.cpp | 111 +++++++++++++----- 13 files changed, 260 insertions(+), 247 deletions(-) diff --git a/lldb/include/lldb/Host/PipeBase.h b/lldb/include/lldb/Host/PipeBase.h index d51d0cd54e036..ed8df6bf1e511 100644 --- a/lldb/include/lldb/Host/PipeBase.h +++ b/lldb/include/lldb/Host/PipeBase.h @@ -10,12 +10,11 @@ #ifndef LLDB_HOST_PIPEBASE_H #define LLDB_HOST_PIPEBASE_H -#include -#include - #include "lldb/Utility/Status.h" +#include "lldb/Utility/Timeout.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" namespace lldb_private { class PipeBase { @@ -32,10 +31,9 @@ class PipeBase { virtual Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) = 0; - Status OpenAsWriter(llvm::StringRef name, bool child_process_inherit); - virtual Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) = 0; + virtual llvm::Error OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) = 0; virtual bool CanRead() const = 0; virtual bool CanWrite() const = 0; @@ -56,14 +54,13 @@ class PipeBase { // Delete named pipe. virtual Status Delete(llvm::StringRef name) = 0; - virtual Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) = 0; - Status Write(const void *buf, size_t size, size_t &bytes_written); - virtual Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) = 0; - Status Read(void *buf, size_t size, size_t &bytes_read); + virtual llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) = 0; + + virtual llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) = 0; }; } diff --git a/lldb/include/lldb/Host/posix/PipePosix.h b/lldb/include/lldb/Host/posix/PipePosix.h index 2e291160817c4..effd33fba7eb0 100644 --- a/lldb/include/lldb/Host/posix/PipePosix.h +++ b/lldb/include/lldb/Host/posix/PipePosix.h @@ -8,6 +8,7 @@ #ifndef LLDB_HOST_POSIX_PIPEPOSIX_H #define LLDB_HOST_POSIX_PIPEPOSIX_H + #include "lldb/Host/PipeBase.h" #include @@ -38,9 +39,8 @@ class PipePosix : public PipeBase { llvm::SmallVectorImpl &name) override; Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) override; - Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) override; + llvm::Error OpenAsWriter(llvm::StringRef name, bool child_process_inherit, + const Timeout &timeout) override; bool CanRead() const override; bool CanWrite() const override; @@ -64,12 +64,13 @@ class PipePosix : public PipeBase { Status Delete(llvm::StringRef name) override; - Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) override; - Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) override; + llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; + + llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; private: bool CanReadUnlocked() const; diff --git a/lldb/include/lldb/Host/windows/PipeWindows.h b/lldb/include/lldb/Host/windows/PipeWindows.h index e28d104cc60ec..9cf591a2d4629 100644 --- a/lldb/include/lldb/Host/windows/PipeWindows.h +++ b/lldb/include/lldb/Host/windows/PipeWindows.h @@ -38,9 +38,8 @@ class PipeWindows : public PipeBase { llvm::SmallVectorImpl &name) override; Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) override; - Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) override; + llvm::Error OpenAsWriter(llvm::StringRef name, bool child_process_inherit, + const Timeout &timeout) override; bool CanRead() const override; bool CanWrite() const override; @@ -59,12 +58,13 @@ class PipeWindows : public PipeBase { Status Delete(llvm::StringRef name) override; - Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) override; - Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) override; + llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; + + llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; // PipeWindows specific methods. These allow access to the underlying OS // handle. diff --git a/lldb/source/Host/common/PipeBase.cpp b/lldb/source/Host/common/PipeBase.cpp index 904a2df12392d..400990f4e41b9 100644 --- a/lldb/source/Host/common/PipeBase.cpp +++ b/lldb/source/Host/common/PipeBase.cpp @@ -11,19 +11,3 @@ using namespace lldb_private; PipeBase::~PipeBase() = default; - -Status PipeBase::OpenAsWriter(llvm::StringRef name, - bool child_process_inherit) { - return OpenAsWriterWithTimeout(name, child_process_inherit, - std::chrono::microseconds::zero()); -} - -Status PipeBase::Write(const void *buf, size_t size, size_t &bytes_written) { - return WriteWithTimeout(buf, size, std::chrono::microseconds::zero(), - bytes_written); -} - -Status PipeBase::Read(void *buf, size_t size, size_t &bytes_read) { - return ReadWithTimeout(buf, size, std::chrono::microseconds::zero(), - bytes_read); -} diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index 30a356f034803..77b80da2cb5ea 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -103,15 +103,14 @@ Status SharedSocket::CompleteSending(lldb::pid_t child_pid) { "WSADuplicateSocket() failed, error: %d", last_error); } - size_t num_bytes; - Status error = - m_socket_pipe.WriteWithTimeout(&protocol_info, sizeof(protocol_info), - std::chrono::seconds(10), num_bytes); - if (error.Fail()) - return error; - if (num_bytes != sizeof(protocol_info)) + llvm::Expected num_bytes = m_socket_pipe.Write( + &protocol_info, sizeof(protocol_info), std::chrono::seconds(10)); + if (!num_bytes) + return Status::FromError(num_bytes.takeError()); + if (*num_bytes != sizeof(protocol_info)) return Status::FromErrorStringWithFormatv( - "WriteWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", num_bytes); + "Write(WSAPROTOCOL_INFO) failed: wrote {0}/{1} bytes", *num_bytes, + sizeof(protocol_info)); #endif return Status(); } @@ -123,16 +122,14 @@ Status SharedSocket::GetNativeSocket(shared_fd_t fd, NativeSocket &socket) { WSAPROTOCOL_INFO protocol_info; { Pipe socket_pipe(fd, LLDB_INVALID_PIPE); - size_t num_bytes; - Status error = - socket_pipe.ReadWithTimeout(&protocol_info, sizeof(protocol_info), - std::chrono::seconds(10), num_bytes); - if (error.Fail()) - return error; - if (num_bytes != sizeof(protocol_info)) { + llvm::Expected num_bytes = socket_pipe.Read( + &protocol_info, sizeof(protocol_info), std::chrono::seconds(10)); + if (!num_bytes) + return Status::FromError(num_bytes.takeError()); + if (*num_bytes != sizeof(protocol_info)) { return Status::FromErrorStringWithFormatv( - "socket_pipe.ReadWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", - num_bytes); + "Read(WSAPROTOCOL_INFO) failed: read {0}/{1} bytes", *num_bytes, + sizeof(protocol_info)); } } socket = ::WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, diff --git a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp index d0cc68826d4bb..e0173e90515c5 100644 --- a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp +++ b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp @@ -183,9 +183,7 @@ ConnectionFileDescriptor::Connect(llvm::StringRef path, } bool ConnectionFileDescriptor::InterruptRead() { - size_t bytes_written = 0; - Status result = m_pipe.Write("i", 1, bytes_written); - return result.Success(); + return !errorToBool(m_pipe.Write("i", 1).takeError()); } ConnectionStatus ConnectionFileDescriptor::Disconnect(Status *error_ptr) { @@ -210,13 +208,11 @@ ConnectionStatus ConnectionFileDescriptor::Disconnect(Status *error_ptr) { std::unique_lock locker(m_mutex, std::defer_lock); if (!locker.try_lock()) { if (m_pipe.CanWrite()) { - size_t bytes_written = 0; - Status result = m_pipe.Write("q", 1, bytes_written); - LLDB_LOGF(log, - "%p ConnectionFileDescriptor::Disconnect(): Couldn't get " - "the lock, sent 'q' to %d, error = '%s'.", - static_cast(this), m_pipe.GetWriteFileDescriptor(), - result.AsCString()); + llvm::Error err = m_pipe.Write("q", 1).takeError(); + LLDB_LOG(log, + "{0}: Couldn't get the lock, sent 'q' to {1}, error = '{2}'.", + this, m_pipe.GetWriteFileDescriptor(), err); + consumeError(std::move(err)); } else if (log) { LLDB_LOGF(log, "%p ConnectionFileDescriptor::Disconnect(): Couldn't get the " diff --git a/lldb/source/Host/posix/MainLoopPosix.cpp b/lldb/source/Host/posix/MainLoopPosix.cpp index 816581e70294a..3106f6e7c0e11 100644 --- a/lldb/source/Host/posix/MainLoopPosix.cpp +++ b/lldb/source/Host/posix/MainLoopPosix.cpp @@ -404,9 +404,5 @@ void MainLoopPosix::TriggerPendingCallbacks() { return; char c = '.'; - size_t bytes_written; - Status error = m_trigger_pipe.Write(&c, 1, bytes_written); - assert(error.Success()); - UNUSED_IF_ASSERT_DISABLED(error); - assert(bytes_written == 1); + cantFail(m_trigger_pipe.Write(&c, 1)); } diff --git a/lldb/source/Host/posix/PipePosix.cpp b/lldb/source/Host/posix/PipePosix.cpp index 24c563d8c24bd..a8c4f8df333a4 100644 --- a/lldb/source/Host/posix/PipePosix.cpp +++ b/lldb/source/Host/posix/PipePosix.cpp @@ -12,7 +12,9 @@ #include "lldb/Utility/SelectHelper.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/Errno.h" +#include "llvm/Support/Error.h" #include +#include #include #include @@ -164,26 +166,27 @@ Status PipePosix::OpenAsReader(llvm::StringRef name, return error; } -Status -PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, - bool child_process_inherit, - const std::chrono::microseconds &timeout) { +llvm::Error PipePosix::OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) { std::lock_guard guard(m_write_mutex); if (CanReadUnlocked() || CanWriteUnlocked()) - return Status::FromErrorString("Pipe is already opened"); + return llvm::createStringError("Pipe is already opened"); int flags = O_WRONLY | O_NONBLOCK; if (!child_process_inherit) flags |= O_CLOEXEC; using namespace std::chrono; - const auto finish_time = Now() + timeout; + std::optional> finish_time; + if (timeout) + finish_time = Now() + *timeout; while (!CanWriteUnlocked()) { - if (timeout != microseconds::zero()) { - const auto dur = duration_cast(finish_time - Now()).count(); - if (dur <= 0) - return Status::FromErrorString( + if (timeout) { + if (Now() > finish_time) + return llvm::createStringError( + std::make_error_code(std::errc::timed_out), "timeout exceeded - reader hasn't opened so far"); } @@ -193,7 +196,8 @@ PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, const auto errno_copy = errno; // We may get ENXIO if a reader side of the pipe hasn't opened yet. if (errno_copy != ENXIO && errno_copy != EINTR) - return Status(errno_copy, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::error_code(errno_copy, std::generic_category())); std::this_thread::sleep_for( milliseconds(OPEN_WRITER_SLEEP_TIMEOUT_MSECS)); @@ -202,7 +206,7 @@ PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, } } - return Status(); + return llvm::Error::success(); } int PipePosix::GetReadFileDescriptor() const { @@ -300,70 +304,51 @@ void PipePosix::CloseWriteFileDescriptorUnlocked() { } } -Status PipePosix::ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) { +llvm::Expected PipePosix::Read(void *buf, size_t size, + const Timeout &timeout) { std::lock_guard guard(m_read_mutex); - bytes_read = 0; if (!CanReadUnlocked()) - return Status(EINVAL, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::make_error_code(std::errc::invalid_argument)); const int fd = GetReadFileDescriptorUnlocked(); SelectHelper select_helper; - select_helper.SetTimeout(timeout); + if (timeout) + select_helper.SetTimeout(*timeout); select_helper.FDSetRead(fd); - Status error; - while (error.Success()) { - error = select_helper.Select(); - if (error.Success()) { - auto result = - ::read(fd, static_cast(buf) + bytes_read, size - bytes_read); - if (result != -1) { - bytes_read += result; - if (bytes_read == size || result == 0) - break; - } else if (errno == EINTR) { - continue; - } else { - error = Status::FromErrno(); - break; - } - } - } - return error; + if (llvm::Error error = select_helper.Select().takeError()) + return error; + + ssize_t result = ::read(fd, buf, size); + if (result == -1) + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + + return result; } -Status PipePosix::WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) { +llvm::Expected PipePosix::Write(const void *buf, size_t size, + const Timeout &timeout) { std::lock_guard guard(m_write_mutex); - bytes_written = 0; if (!CanWriteUnlocked()) - return Status(EINVAL, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::make_error_code(std::errc::invalid_argument)); const int fd = GetWriteFileDescriptorUnlocked(); SelectHelper select_helper; - select_helper.SetTimeout(timeout); + if (timeout) + select_helper.SetTimeout(*timeout); select_helper.FDSetWrite(fd); - Status error; - while (error.Success()) { - error = select_helper.Select(); - if (error.Success()) { - auto result = ::write(fd, static_cast(buf) + bytes_written, - size - bytes_written); - if (result != -1) { - bytes_written += result; - if (bytes_written == size) - break; - } else if (errno == EINTR) { - continue; - } else { - error = Status::FromErrno(); - } - } - } - return error; + if (llvm::Error error = select_helper.Select().takeError()) + return error; + + ssize_t result = ::write(fd, buf, size); + if (result == -1) + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + + return result; } diff --git a/lldb/source/Host/windows/PipeWindows.cpp b/lldb/source/Host/windows/PipeWindows.cpp index d79dc3c2f82c9..a13929b65e087 100644 --- a/lldb/source/Host/windows/PipeWindows.cpp +++ b/lldb/source/Host/windows/PipeWindows.cpp @@ -151,14 +151,13 @@ Status PipeWindows::OpenAsReader(llvm::StringRef name, return OpenNamedPipe(name, child_process_inherit, true); } -Status -PipeWindows::OpenAsWriterWithTimeout(llvm::StringRef name, - bool child_process_inherit, - const std::chrono::microseconds &timeout) { +llvm::Error PipeWindows::OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) { if (CanWrite()) - return Status(); // Note the name is ignored. + return llvm::Error::success(); // Note the name is ignored. - return OpenNamedPipe(name, child_process_inherit, false); + return OpenNamedPipe(name, child_process_inherit, false).takeError(); } Status PipeWindows::OpenNamedPipe(llvm::StringRef name, @@ -270,29 +269,24 @@ PipeWindows::GetReadNativeHandle() { return m_read; } HANDLE PipeWindows::GetWriteNativeHandle() { return m_write; } -Status PipeWindows::ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &duration, - size_t &bytes_read) { +llvm::Expected PipeWindows::Read(void *buf, size_t size, + const Timeout &timeout) { if (!CanRead()) - return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32); + return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32).takeError(); - bytes_read = 0; - DWORD sys_bytes_read = 0; - BOOL result = - ::ReadFile(m_read, buf, size, &sys_bytes_read, &m_read_overlapped); - if (result) { - bytes_read = sys_bytes_read; - return Status(); - } + DWORD bytes_read = 0; + BOOL result = ::ReadFile(m_read, buf, size, &bytes_read, &m_read_overlapped); + if (result) + return bytes_read; DWORD failure_error = ::GetLastError(); if (failure_error != ERROR_IO_PENDING) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); - DWORD timeout = (duration == std::chrono::microseconds::zero()) - ? INFINITE - : duration.count() / 1000; - DWORD wait_result = ::WaitForSingleObject(m_read_overlapped.hEvent, timeout); + DWORD timeout_msec = + timeout ? ceil(*timeout).count() : INFINITE; + DWORD wait_result = + ::WaitForSingleObject(m_read_overlapped.hEvent, timeout_msec); if (wait_result != WAIT_OBJECT_0) { // The operation probably failed. However, if it timed out, we need to // cancel the I/O. Between the time we returned from WaitForSingleObject @@ -308,42 +302,36 @@ Status PipeWindows::ReadWithTimeout(void *buf, size_t size, failed = false; } if (failed) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); } // Now we call GetOverlappedResult setting bWait to false, since we've // already waited as long as we're willing to. - if (!::GetOverlappedResult(m_read, &m_read_overlapped, &sys_bytes_read, - FALSE)) - return Status(::GetLastError(), eErrorTypeWin32); + if (!::GetOverlappedResult(m_read, &m_read_overlapped, &bytes_read, FALSE)) + return Status(::GetLastError(), eErrorTypeWin32).takeError(); - bytes_read = sys_bytes_read; - return Status(); + return bytes_read; } -Status PipeWindows::WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &duration, - size_t &bytes_written) { +llvm::Expected PipeWindows::Write(const void *buf, size_t size, + const Timeout &timeout) { if (!CanWrite()) - return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32); + return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32).takeError(); - bytes_written = 0; - DWORD sys_bytes_write = 0; + DWORD bytes_written = 0; BOOL result = - ::WriteFile(m_write, buf, size, &sys_bytes_write, &m_write_overlapped); - if (result) { - bytes_written = sys_bytes_write; - return Status(); - } + ::WriteFile(m_write, buf, size, &bytes_written, &m_write_overlapped); + if (result) + return bytes_written; DWORD failure_error = ::GetLastError(); if (failure_error != ERROR_IO_PENDING) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); - DWORD timeout = (duration == std::chrono::microseconds::zero()) - ? INFINITE - : duration.count() / 1000; - DWORD wait_result = ::WaitForSingleObject(m_write_overlapped.hEvent, timeout); + DWORD timeout_msec = + timeout ? ceil(*timeout).count() : INFINITE; + DWORD wait_result = + ::WaitForSingleObject(m_write_overlapped.hEvent, timeout_msec); if (wait_result != WAIT_OBJECT_0) { // The operation probably failed. However, if it timed out, we need to // cancel the I/O. Between the time we returned from WaitForSingleObject @@ -359,15 +347,14 @@ Status PipeWindows::WriteWithTimeout(const void *buf, size_t size, failed = false; } if (failed) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); } // Now we call GetOverlappedResult setting bWait to false, since we've // already waited as long as we're willing to. - if (!::GetOverlappedResult(m_write, &m_write_overlapped, &sys_bytes_write, + if (!::GetOverlappedResult(m_write, &m_write_overlapped, &bytes_written, FALSE)) - return Status(::GetLastError(), eErrorTypeWin32); + return Status(::GetLastError(), eErrorTypeWin32).takeError(); - bytes_written = sys_bytes_write; - return Status(); + return bytes_written; } diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp index d39ae79fd84f9..a7a04bb521697 100644 --- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp +++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp @@ -30,7 +30,11 @@ #include "lldb/Utility/Log.h" #include "lldb/Utility/RegularExpression.h" #include "lldb/Utility/StreamString.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_ZLIB +#include "llvm/ADT/StringRef.h" +#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_ZLIB #include "llvm/Support/ScopedPrinter.h" #include "ProcessGDBRemoteLog.h" @@ -1147,15 +1151,25 @@ Status GDBRemoteCommunication::StartDebugserverProcess( if (socket_pipe.CanRead()) { char port_cstr[PATH_MAX] = {0}; port_cstr[0] = '\0'; - size_t num_bytes = sizeof(port_cstr); // Read port from pipe with 10 second timeout. - error = socket_pipe.ReadWithTimeout( - port_cstr, num_bytes, std::chrono::seconds{10}, num_bytes); + std::string port_str; + while (error.Success()) { + char buf[10]; + if (llvm::Expected num_bytes = socket_pipe.Read( + buf, std::size(buf), std::chrono::seconds(10))) { + if (*num_bytes == 0) + break; + port_str.append(buf, *num_bytes); + } else { + error = Status::FromError(num_bytes.takeError()); + } + } if (error.Success() && (port != nullptr)) { - assert(num_bytes > 0 && port_cstr[num_bytes - 1] == '\0'); + // NB: Deliberately using .c_str() to stop at embedded '\0's + llvm::StringRef port_ref = port_str.c_str(); uint16_t child_port = 0; // FIXME: improve error handling - llvm::to_integer(port_cstr, child_port); + llvm::to_integer(port_ref, child_port); if (*port == 0 || *port == child_port) { *port = child_port; LLDB_LOGF(log, diff --git a/lldb/source/Target/Process.cpp b/lldb/source/Target/Process.cpp index 294b7fbd1ee6b..70bbfb16ffc61 100644 --- a/lldb/source/Target/Process.cpp +++ b/lldb/source/Target/Process.cpp @@ -4760,15 +4760,16 @@ class IOHandlerProcessSTDIO : public IOHandler { } if (select_helper.FDIsSetRead(pipe_read_fd)) { - size_t bytes_read; // Consume the interrupt byte - Status error = m_pipe.Read(&ch, 1, bytes_read); - if (error.Success()) { + if (llvm::Expected bytes_read = m_pipe.Read(&ch, 1)) { if (ch == 'q') break; if (ch == 'i') if (StateIsRunningState(m_process->GetState())) m_process->SendAsyncInterrupt(); + } else { + LLDB_LOG_ERROR(GetLog(LLDBLog::Process), bytes_read.takeError(), + "Pipe read failed: {0}"); } } } @@ -4792,8 +4793,10 @@ class IOHandlerProcessSTDIO : public IOHandler { // deadlocking when the pipe gets fed up and blocks until data is consumed. if (m_is_running) { char ch = 'q'; // Send 'q' for quit - size_t bytes_written = 0; - m_pipe.Write(&ch, 1, bytes_written); + if (llvm::Error err = m_pipe.Write(&ch, 1).takeError()) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Process), std::move(err), + "Pipe write failed: {0}"); + } } } @@ -4805,9 +4808,7 @@ class IOHandlerProcessSTDIO : public IOHandler { // m_process->SendAsyncInterrupt() from a much safer location in code. if (m_active) { char ch = 'i'; // Send 'i' for interrupt - size_t bytes_written = 0; - Status result = m_pipe.Write(&ch, 1, bytes_written); - return result.Success(); + return !errorToBool(m_pipe.Write(&ch, 1).takeError()); } else { // This IOHandler might be pushed on the stack, but not being run // currently so do the right thing if we aren't actively watching for diff --git a/lldb/tools/lldb-server/lldb-gdbserver.cpp b/lldb/tools/lldb-server/lldb-gdbserver.cpp index 563284730bc70..1ecbdad3ca5c0 100644 --- a/lldb/tools/lldb-server/lldb-gdbserver.cpp +++ b/lldb/tools/lldb-server/lldb-gdbserver.cpp @@ -167,27 +167,35 @@ void handle_launch(GDBRemoteCommunicationServerLLGS &gdb_server, } } -Status writeSocketIdToPipe(Pipe &port_pipe, llvm::StringRef socket_id) { - size_t bytes_written = 0; - // Write the port number as a C string with the NULL terminator. - return port_pipe.Write(socket_id.data(), socket_id.size() + 1, bytes_written); +static Status writeSocketIdToPipe(Pipe &port_pipe, + const std::string &socket_id) { + // NB: Include the nul character at the end. + llvm::StringRef buf(socket_id.data(), socket_id.size() + 1); + while (!buf.empty()) { + if (llvm::Expected written = + port_pipe.Write(buf.data(), buf.size())) + buf = buf.drop_front(*written); + else + return Status::FromError(written.takeError()); + } + return Status(); } Status writeSocketIdToPipe(const char *const named_pipe_path, llvm::StringRef socket_id) { Pipe port_name_pipe; // Wait for 10 seconds for pipe to be opened. - auto error = port_name_pipe.OpenAsWriterWithTimeout(named_pipe_path, false, - std::chrono::seconds{10}); - if (error.Fail()) - return error; - return writeSocketIdToPipe(port_name_pipe, socket_id); + if (llvm::Error err = port_name_pipe.OpenAsWriter(named_pipe_path, false, + std::chrono::seconds{10})) + return Status::FromError(std::move(err)); + + return writeSocketIdToPipe(port_name_pipe, socket_id.str()); } Status writeSocketIdToPipe(lldb::pipe_t unnamed_pipe, llvm::StringRef socket_id) { Pipe port_pipe{LLDB_INVALID_PIPE, unnamed_pipe}; - return writeSocketIdToPipe(port_pipe, socket_id); + return writeSocketIdToPipe(port_pipe, socket_id.str()); } void ConnectToRemote(MainLoop &mainloop, diff --git a/lldb/unittests/Host/PipeTest.cpp b/lldb/unittests/Host/PipeTest.cpp index 506f3d225a21e..a3a492648def6 100644 --- a/lldb/unittests/Host/PipeTest.cpp +++ b/lldb/unittests/Host/PipeTest.cpp @@ -10,9 +10,13 @@ #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" +#include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include #include +#include #include +#include #include using namespace lldb_private; @@ -85,57 +89,53 @@ TEST_F(PipeTest, WriteWithTimeout) { char *read_ptr = reinterpret_cast(read_buf.data()); size_t write_bytes = 0; size_t read_bytes = 0; - size_t num_bytes = 0; // Write to the pipe until it is full. while (write_bytes + write_chunk_size <= buf_size) { - Status error = - pipe.WriteWithTimeout(write_ptr + write_bytes, write_chunk_size, - std::chrono::milliseconds(10), num_bytes); - if (error.Fail()) + llvm::Expected num_bytes = + pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::milliseconds(10)); + if (num_bytes) { + write_bytes += *num_bytes; + } else { + ASSERT_THAT_ERROR(num_bytes.takeError(), llvm::Failed()); break; // The write buffer is full. - write_bytes += num_bytes; + } } ASSERT_LE(write_bytes + write_chunk_size, buf_size) << "Pipe buffer larger than expected"; // Attempt a write with a long timeout. auto start_time = std::chrono::steady_clock::now(); - ASSERT_THAT_ERROR(pipe.WriteWithTimeout(write_ptr + write_bytes, - write_chunk_size, - std::chrono::seconds(2), num_bytes) - .ToError(), - llvm::Failed()); + // TODO: Assert a specific error (EAGAIN?) here. + ASSERT_THAT_EXPECTED(pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::seconds(2)), + llvm::Failed()); auto dur = std::chrono::steady_clock::now() - start_time; ASSERT_GE(dur, std::chrono::seconds(2)); // Attempt a write with a short timeout. start_time = std::chrono::steady_clock::now(); - ASSERT_THAT_ERROR( - pipe.WriteWithTimeout(write_ptr + write_bytes, write_chunk_size, - std::chrono::milliseconds(200), num_bytes) - .ToError(), - llvm::Failed()); + ASSERT_THAT_EXPECTED(pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::milliseconds(200)), + llvm::Failed()); dur = std::chrono::steady_clock::now() - start_time; ASSERT_GE(dur, std::chrono::milliseconds(200)); ASSERT_LT(dur, std::chrono::seconds(2)); // Drain the pipe. while (read_bytes < write_bytes) { - ASSERT_THAT_ERROR( - pipe.ReadWithTimeout(read_ptr + read_bytes, write_bytes - read_bytes, - std::chrono::milliseconds(10), num_bytes) - .ToError(), - llvm::Succeeded()); - read_bytes += num_bytes; + llvm::Expected num_bytes = + pipe.Read(read_ptr + read_bytes, write_bytes - read_bytes, + std::chrono::milliseconds(10)); + ASSERT_THAT_EXPECTED(num_bytes, llvm::Succeeded()); + read_bytes += *num_bytes; } // Be sure the pipe is empty. - ASSERT_THAT_ERROR(pipe.ReadWithTimeout(read_ptr + read_bytes, 100, - std::chrono::milliseconds(10), - num_bytes) - .ToError(), - llvm::Failed()); + ASSERT_THAT_EXPECTED( + pipe.Read(read_ptr + read_bytes, 100, std::chrono::milliseconds(10)), + llvm::Failed()); // Check that we got what we wrote. ASSERT_EQ(write_bytes, read_bytes); @@ -144,9 +144,56 @@ TEST_F(PipeTest, WriteWithTimeout) { read_buf.begin())); // Write to the pipe again and check that it succeeds. - ASSERT_THAT_ERROR(pipe.WriteWithTimeout(write_ptr, write_chunk_size, - std::chrono::milliseconds(10), - num_bytes) - .ToError(), - llvm::Succeeded()); + ASSERT_THAT_EXPECTED( + pipe.Write(write_ptr, write_chunk_size, std::chrono::milliseconds(10)), + llvm::Succeeded()); +} + +TEST_F(PipeTest, ReadWithTimeout) { + Pipe pipe; + ASSERT_THAT_ERROR(pipe.CreateNew(false).ToError(), llvm::Succeeded()); + + char buf[100]; + // The pipe is initially empty. A polling read returns immediately. + ASSERT_THAT_EXPECTED(pipe.Read(buf, sizeof(buf), std::chrono::seconds(0)), + llvm::Failed()); + + // With a timeout, we should wait for at least this amount of time (but not + // too much). + auto start = std::chrono::steady_clock::now(); + ASSERT_THAT_EXPECTED( + pipe.Read(buf, sizeof(buf), std::chrono::milliseconds(200)), + llvm::Failed()); + auto dur = std::chrono::steady_clock::now() - start; + EXPECT_GT(dur, std::chrono::milliseconds(200)); + EXPECT_LT(dur, std::chrono::seconds(2)); + + // Write something into the pipe, and read it back. The blocking read call + // should return even though it hasn't filled the buffer. + llvm::StringRef hello_world("Hello world!"); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(pipe.Read(buf, sizeof(buf)), + llvm::HasValue(hello_world.size())); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); + + // Now write something and try to read it in chunks. + memset(buf, 0, sizeof(buf)); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(pipe.Read(buf, 4), llvm::HasValue(4)); + ASSERT_THAT_EXPECTED(pipe.Read(buf + 4, sizeof(buf) - 4), + llvm::HasValue(hello_world.size() - 4)); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); + + // A blocking read should wait until the data arrives. + memset(buf, 0, sizeof(buf)); + std::future> future_num_bytes = std::async( + std::launch::async, [&] { return pipe.Read(buf, sizeof(buf)); }); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(future_num_bytes.get(), + llvm::HasValue(hello_world.size())); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); } From 5f22195f6b1e986c963aaeb1965c25b28999af0e Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Thu, 31 Jul 2025 09:39:53 -0700 Subject: [PATCH 14/15] Fix the unit tests without Swift support --- lldb/unittests/Core/SwiftDemanglingPartsTest.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp b/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp index f0cb565da028a..a781b99e02be0 100644 --- a/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp +++ b/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp @@ -6,15 +6,17 @@ // //===----------------------------------------------------------------------===// -#include "Plugins/Language/Swift/SwiftMangled.h" -#include "Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h" #include "TestingSupport/TestUtilities.h" - #include "lldb/Core/DemangledNameInfo.h" #include "lldb/Core/Mangled.h" - +#include "lldb/Host/Config.h" #include "gtest/gtest.h" +#ifdef LLDB_ENABLE_SWIFT + +#include "Plugins/Language/Swift/SwiftMangled.h" +#include "Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h" + using namespace lldb; using namespace lldb_private; @@ -1603,4 +1605,6 @@ TEST_P(SwiftDemanglingPartsTestFixture, SwiftDemanglingParts) { INSTANTIATE_TEST_SUITE_P( SwiftDemanglingPartsTests, SwiftDemanglingPartsTestFixture, - ::testing::ValuesIn(g_swift_demangling_parts_test_cases)); \ No newline at end of file + ::testing::ValuesIn(g_swift_demangling_parts_test_cases)); + +#endif From 8a845a69f13dc07556ae04d7f3ba126ba835bb39 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere Date: Mon, 23 Jun 2025 10:35:42 -0700 Subject: [PATCH 15/15] [lldb] Use `proc` instead of `pro` to avoid command ambiguity Use `proc` instead of `pro` to avoid ambiguity between the `process` and `protocol-server` command. (cherry picked from commit e391301e0e4d9183fe06e69602e87b0bc889aeda) --- .../API/functionalities/unwind/sigtramp/TestSigtrampUnwind.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lldb/test/API/functionalities/unwind/sigtramp/TestSigtrampUnwind.py b/lldb/test/API/functionalities/unwind/sigtramp/TestSigtrampUnwind.py index 3476647a12ecb..4736d48151793 100644 --- a/lldb/test/API/functionalities/unwind/sigtramp/TestSigtrampUnwind.py +++ b/lldb/test/API/functionalities/unwind/sigtramp/TestSigtrampUnwind.py @@ -44,7 +44,7 @@ def test(self): ) self.expect( - "pro handle -n false -p true -s false SIGUSR1", + "proc handle -n false -p true -s false SIGUSR1", "Have lldb pass SIGUSR1 signals", substrs=["SIGUSR1", "true", "false", "false"], )