diff --git a/easywsclient.cpp b/easywsclient.cpp index 0c98098..dec6ca9 100644 --- a/easywsclient.cpp +++ b/easywsclient.cpp @@ -1,4 +1,5 @@ +#include #ifdef _WIN32 #if defined(_MSC_VER) && !defined(_CRT_SECURE_NO_WARNINGS) #define _CRT_SECURE_NO_WARNINGS // _CRT_SECURE_NO_WARNINGS for sscanf errors in MSVC2013 Express @@ -119,7 +120,7 @@ class _DummyWebSocket : public easywsclient::WebSocket void sendBinary(const std::string& message) { } void sendBinary(const std::vector& message) { } void sendPing() { } - void close() { } + void close() { } readyStateValues getReadyState() const { return CLOSED; } void _dispatch(Callback_Imp & callable) { } void _dispatchBinary(BytesCallback_Imp& callable) { } @@ -340,7 +341,7 @@ class _RealWebSocket : public easywsclient::WebSocket // We got a whole message, now do something with it: if (false) { } else if ( - ws.opcode == wsheader_type::TEXT_FRAME + ws.opcode == wsheader_type::TEXT_FRAME || ws.opcode == wsheader_type::BINARY_FRAME || ws.opcode == wsheader_type::CONTINUATION ) { @@ -454,7 +455,7 @@ class _RealWebSocket : public easywsclient::WebSocket }; -easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, const std::string& origin) { +easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, const std::string& origin, const std::map& extraHeaders) { char host[512]; int port; char path[512]; @@ -506,6 +507,12 @@ easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, if (!origin.empty()) { snprintf(line, 1024, "Origin: %s\r\n", origin.c_str()); ::send(sockfd, line, strlen(line), 0); } + if (!extraHeaders.empty()) { + for (auto header : extraHeaders) { + snprintf(line, 1024, "%s: %s\r\n", header.first.c_str(), header.second.c_str()); + ::send(sockfd, line, strlen(line), 0); + } + } snprintf(line, 1024, "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"); ::send(sockfd, line, strlen(line), 0); snprintf(line, 1024, "Sec-WebSocket-Version: 13\r\n"); ::send(sockfd, line, strlen(line), 0); snprintf(line, 1024, "\r\n"); ::send(sockfd, line, strlen(line), 0); @@ -543,12 +550,20 @@ WebSocket::pointer WebSocket::create_dummy() { } -WebSocket::pointer WebSocket::from_url(const std::string& url, const std::string& origin) { - return ::from_url(url, true, origin); +WebSocket::pointer WebSocket::from_url( + const std::string& url, + const std::string& origin, + const std::map& extraHeaders + ) { + return ::from_url(url, true, origin, extraHeaders); } -WebSocket::pointer WebSocket::from_url_no_mask(const std::string& url, const std::string& origin) { - return ::from_url(url, false, origin); +WebSocket::pointer WebSocket::from_url_no_mask( + const std::string& url, + const std::string& origin, + const std::map& extraHeaders + ) { + return ::from_url(url, false, origin, extraHeaders); } diff --git a/easywsclient.hpp b/easywsclient.hpp index 08c4a7b..dd3edf5 100644 --- a/easywsclient.hpp +++ b/easywsclient.hpp @@ -8,6 +8,7 @@ // wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.hpp // wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.cpp +#include #include #include @@ -23,8 +24,16 @@ class WebSocket { // Factories: static pointer create_dummy(); - static pointer from_url(const std::string& url, const std::string& origin = std::string()); - static pointer from_url_no_mask(const std::string& url, const std::string& origin = std::string()); + static pointer from_url( + const std::string& url, + const std::string& origin = std::string(), + const std::map& extraHeaders = std::map() + ); + static pointer from_url_no_mask( + const std::string& url, + const std::string& origin = std::string(), + const std::map& extraHeaders = std::map() + ); // Interfaces: virtual ~WebSocket() { } diff --git a/example-client-cpp11.cpp b/example-client-cpp11.cpp index 5298c36..d2480c0 100644 --- a/example-client-cpp11.cpp +++ b/example-client-cpp11.cpp @@ -25,7 +25,11 @@ int main() } #endif - std::unique_ptr ws(WebSocket::from_url("ws://localhost:8126/foo")); + std::map headers; + headers.insert(std::make_pair("Authorization", "Bearer 123")); + std::string origin = "https://example.com"; + + std::unique_ptr ws(WebSocket::from_url("ws://localhost:8126/foo", origin, headers)); assert(ws); ws->send("goodbye"); ws->send("hello"); diff --git a/example-client.cpp b/example-client.cpp index abc8159..cca1f4c 100644 --- a/example-client.cpp +++ b/example-client.cpp @@ -29,8 +29,11 @@ int main() return 1; } #endif + std::map headers; + headers.insert(std::make_pair("Authorization", "Bearer 123")); + std::string origin = "https://example.com"; - ws = WebSocket::from_url("ws://localhost:8126/foo"); + ws = WebSocket::from_url("ws://localhost:8126/foo", origin, headers); assert(ws); ws->send("goodbye"); ws->send("hello");