diff --git a/MD5.c b/MD5.c index 5edaf5c..7d2a06f 100644 --- a/MD5.c +++ b/MD5.c @@ -20,7 +20,7 @@ * These notices must be retained in any copies of any part of this * documentation and/or software. */ - +#ifndef ESP32 #include "global.h" #include "MD5.h" @@ -300,4 +300,5 @@ void MD5(unsigned char strInputString[], unsigned char md5Digest[], unsigned int MD5Update(&ctx, strInputString, len); MD5Final(md5Digest, &ctx); -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/MD5.h b/MD5.h index c434fca..94475b5 100644 --- a/MD5.h +++ b/MD5.h @@ -19,7 +19,8 @@ * These notices must be retained in any copies of any part of this * documentation and/or software. */ - +#ifndef MD5_H +#define MD5_H /* MD5 context. */ typedef struct { UINT4 state[4]; /* state (ABCD) */ @@ -32,4 +33,5 @@ void MD5Update (MD5_CTX *, unsigned char *, unsigned int); void MD5Final (unsigned char [16], MD5_CTX *); /* Function used by Websockets implementation */ -void MD5 (unsigned char [], unsigned char [], unsigned int); \ No newline at end of file +void MD5 (unsigned char [], unsigned char [], unsigned int); +#endif /*MD5_H*/ \ No newline at end of file diff --git a/README.md b/README.md index 2eea2f0..4c10d6a 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,6 @@ Inside of the WebSocketServer class there is a compiler directive to turn on sup Because of limitations of the current Arduino platform (Uno at the time of this writing), this library does not support messages larger than 65535 characters. In addition, this library only supports single-frame text frames. It currently does not recognize continuation frames, binary frames, or ping/pong frames. ### Credits -Thank you to github user ejeklint for the excellent starting point for this library. From his original Hixie76-only code I was able to add support for RFC 6455 and create the WebSocket client. +Thank you to github user ejeklint and brandenhall for the excellent starting point for this library. From ejeklint's original Hixie76-only code brandenhall was able to add support for RFC 6455 and create the WebSocket client. Then I only had to improve stability and speed. -- Branden \ No newline at end of file +- Pablo diff --git a/WebSocketClient.cpp b/WebSocketClient.cpp index bad1b9f..8100ed0 100644 --- a/WebSocketClient.cpp +++ b/WebSocketClient.cpp @@ -4,10 +4,33 @@ #include "WebSocketClient.h" #include "sha1.h" -#include "base64.h" +#include "Base64.h" + +WebSocketClient::WebSocketClient(char *WsPath, char *WsHost, char *WsHeaders, char *WsProtocol){ + path = WsPath; + host = WsHost; + headers = WsHeaders; + protocol = WsProtocol; +#ifdef WS_BUFFERED_SEND + bufferIndex = 0; +#endif +} + +void WebSocketClient::setHost(char* WsHost) { + host = WsHost; +} +void WebSocketClient::setProtocol(char* WsProtocol) { + protocol = WsProtocol; +} +void WebSocketClient::setPath(char* WsPath) { + path = WsPath; +} +void WebSocketClient::setHeaders(char* WsHeaders) { + headers = WsHeaders; +} -bool WebSocketClient::handshake(Client &client) { +int WebSocketClient::handshake(Client &client) { socket_client = &client; @@ -17,12 +40,13 @@ bool WebSocketClient::handshake(Client &client) { #ifdef DEBUGGING Serial.println(F("Client connected")); #endif - if (analyzeRequest()) { + int res = analyzeRequest(); + if (res==101) { #ifdef DEBUGGING Serial.println(F("Websocket established")); #endif - return true; + return res; } else { // Might just need to break until out of socket_client loop. @@ -31,19 +55,21 @@ bool WebSocketClient::handshake(Client &client) { #endif disconnectStream(); - return false; + return res; } } else { - return false; + return -1; } } -bool WebSocketClient::analyzeRequest() { +int WebSocketClient::analyzeRequest() { String temp; int bite; bool foundupgrade = false; unsigned long intkey[2]; + String sAnswerCode; + int answerCode=0; String serverKey; char keyStart[17]; char b64Key[25]; @@ -64,7 +90,7 @@ bool WebSocketClient::analyzeRequest() { #ifdef DEBUGGING Serial.println(F("Sending websocket upgrade headers")); #endif - +#ifndef WS_BUFFERED_SEND socket_client->print(F("GET ")); socket_client->print(path); socket_client->print(F(" HTTP/1.1\r\n")); @@ -72,39 +98,118 @@ bool WebSocketClient::analyzeRequest() { socket_client->print(F("Connection: Upgrade\r\n")); socket_client->print(F("Host: ")); socket_client->print(host); - socket_client->print(CRLF); + socket_client->print(F(CRLF)); + if(headers) { + socket_client->print(headers); + socket_client->print(F(CRLF)); + } socket_client->print(F("Sec-WebSocket-Key: ")); socket_client->print(key); - socket_client->print(CRLF); - socket_client->print(F("Sec-WebSocket-Protocol: ")); - socket_client->print(protocol); - socket_client->print(CRLF); + socket_client->print(F(CRLF)); + if(protocol!=NULL) { + socket_client->print(F("Sec-WebSocket-Protocol: ")); + socket_client->print(protocol); + socket_client->print(F(CRLF)); + } socket_client->print(F("Sec-WebSocket-Version: 13\r\n")); - socket_client->print(CRLF); + socket_client->print(F(CRLF)); +#else + bufferIndex = 0; + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("GET "), 4); + bufferIndex+=4; + strcpy((char *)&buffer[bufferIndex],path); + bufferIndex+=strlen(path); + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(" HTTP/1.1\r\n"), 11); + bufferIndex+=11; + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Upgrade: websocket\r\n"), 20); + bufferIndex+=20; + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Connection: Upgrade\r\n"), 21); + bufferIndex+=21; + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Host: "), 6); + bufferIndex+=6; + strcpy((char *)&buffer[bufferIndex],host); + bufferIndex+=strlen(host); + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(CRLF), 2); + bufferIndex+=2; + if(headers) { + strcpy((char *)&buffer[bufferIndex],headers); + bufferIndex+=strlen(headers); + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(CRLF), 2); + bufferIndex+=2; + } + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Sec-WebSocket-Key: "), 19); + bufferIndex+=19; + strcpy((char *)&buffer[bufferIndex],&key[0]); + bufferIndex+=key.length(); + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(CRLF), 2); + bufferIndex+=2; + if(protocol!=NULL) { + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Sec-WebSocket-Protocol: "), 24); + bufferIndex+=24; + strcpy((char *)&buffer[bufferIndex],protocol); + bufferIndex+=strlen(protocol); + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(CRLF), 2); + bufferIndex+=2; + } + strncpy_P((char *)&buffer[bufferIndex],(const char *)F("Sec-WebSocket-Version: 13\r\n"), 27); + bufferIndex+=27; + strncpy_P((char *)&buffer[bufferIndex],(const char *)F(CRLF), 2); + bufferIndex+=2; + if(socket_client->write(buffer, bufferIndex)){ +#ifdef DEBUGGING + Serial.print("Sending: "); + int i; + for(i=0; iconnected() && !socket_client->available()) { delay(100); +#ifdef DEBUGGING Serial.println("Waiting..."); +#endif } - +#ifdef DEBUGGING + if(!socket_client->connected()) Serial.println("Error. Broken connection"); +#endif // TODO: More robust string extraction while ((bite = socket_client->read()) != -1) { temp += (char)bite; - if ((char)bite == '\n') { + if(temp.startsWith("\r\n")) { +#ifdef DEBUGGING + Serial.println("End of headers"); +#endif + break; + } else if (!foundupgrade && temp.startsWith("Upgrade: websocket")) { + foundupgrade = true; + } else if (temp.startsWith("Sec-WebSocket-Accept: ")) { + serverKey = temp.substring(22,temp.length() - 2); // Don't save last CR+LF + } else if(temp.startsWith("HTTP/1.1 ")){ + int i; + for(i=9; i< temp.length(); i++) if (!isDigit(temp[i])) break; + sAnswerCode = temp.substring(9,i); + answerCode = atoi(&sAnswerCode[0]); #ifdef DEBUGGING - Serial.print("Got Header: " + temp); + Serial.print("Answer Code: "); + Serial.println(answerCode); #endif - if (!foundupgrade && temp.startsWith("Upgrade: websocket")) { - foundupgrade = true; - } else if (temp.startsWith("Sec-WebSocket-Accept: ")) { - serverKey = temp.substring(22,temp.length() - 2); // Don't save last CR+LF - } + } +#ifdef DEBUGGING + Serial.print("Got Header: " + temp); +#endif temp = ""; } @@ -130,34 +235,25 @@ bool WebSocketClient::analyzeRequest() { base64_encode(b64Result, result, 20); // if the keys match, good to go - return serverKey.equals(String(b64Result)); + if (answerCode==101 && serverKey.equals(String(b64Result)) || answerCode!=101 && answerCode!=0) return answerCode; + else return -1; } - -bool WebSocketClient::handleStream(String& data, uint8_t *opcode) { - uint8_t msgtype; - uint8_t bite; - unsigned int length; - uint8_t mask[4]; - uint8_t index; - unsigned int i; - bool hasMask = false; - - if (!socket_client->connected() || !socket_client->available()) - { +bool WebSocketClient::handleMessageHeader(uint8_t *msgtype, unsigned int *length, bool *hasMask, uint8_t *mask, uint8_t *opcode) { + if (!socket_client->connected() || !socket_client->available()) { return false; } - msgtype = timedRead(); + *msgtype = timedRead(); if (!socket_client->connected()) { return false; } - length = timedRead(); + *length = timedRead(); - if (length & WS_MASK) { - hasMask = true; - length = length & ~WS_MASK; + if (*length & WS_MASK) { + *hasMask = true; + *length = *length & ~WS_MASK; } @@ -165,27 +261,25 @@ bool WebSocketClient::handleStream(String& data, uint8_t *opcode) { return false; } - index = 6; - - if (length == WS_SIZE16) { - length = timedRead() << 8; + if (*length == WS_SIZE16) { + *length = timedRead() << 8; if (!socket_client->connected()) { return false; } - length |= timedRead(); + *length |= timedRead(); if (!socket_client->connected()) { return false; } - } else if (length == WS_SIZE64) { + } else if (*length == WS_SIZE64) { #ifdef DEBUGGING Serial.println(F("No support for over 16 bit sized messages")); #endif return false; } - if (hasMask) { + if (*hasMask) { // get the mask mask[0] = timedRead(); if (!socket_client->connected()) { @@ -208,23 +302,34 @@ bool WebSocketClient::handleStream(String& data, uint8_t *opcode) { return false; } } - - data = ""; - + if (opcode != NULL) { - *opcode = msgtype & ~WS_FIN; + *opcode = *msgtype & ~WS_FIN; } - + + return true; +} + +bool WebSocketClient::handleStream(String& data, uint8_t *opcode) { + uint8_t msgtype; + unsigned int length; + uint8_t mask[4]; + bool hasMask = false; + + if(!handleMessageHeader(&msgtype, &length, &hasMask, mask, opcode)) return false; + + data = ""; + if (hasMask) { - for (i=0; iconnected()) { return false; } } } else { - for (i=0; iconnected()) { return false; @@ -235,6 +340,35 @@ bool WebSocketClient::handleStream(String& data, uint8_t *opcode) { return true; } +bool WebSocketClient::handleStream(char *data, unsigned int dataLen, uint8_t *opcode) { + uint8_t msgtype; + unsigned int length; + uint8_t mask[4]; + bool hasMask = false; + + if(!handleMessageHeader(&msgtype, &length, &hasMask, mask, opcode)) return false; + + int i; + int limit = length>dataLen?dataLen:length; + if (hasMask) { + for (i=0; iconnected()) { + return false; + } + } + } else { + for (i=0; iconnected()) { + return false; + } + } + } + data[i] = '\0'; + return true; +} + void WebSocketClient::disconnectStream() { #ifdef DEBUGGING Serial.println(F("Terminating socket")); @@ -248,14 +382,24 @@ void WebSocketClient::disconnectStream() { socket_client->stop(); } +int WebSocketClient::connected(void) { + return socket_client->connected(); +} + bool WebSocketClient::getData(String& data, uint8_t *opcode) { return handleStream(data, opcode); } +bool WebSocketClient::getData(char *data, unsigned int dataLen, uint8_t *opcode) { + return handleStream(data, dataLen, opcode); +} + void WebSocketClient::sendData(const char *str, uint8_t opcode) { #ifdef DEBUGGING - Serial.print(F("Sending data: ")); - Serial.println(str); + if((char)str[0]!=0) { + Serial.print(F("Sending data: ")); + Serial.println(str); + } #endif if (socket_client->connected()) { sendEncodedData(str, opcode); @@ -264,8 +408,10 @@ void WebSocketClient::sendData(const char *str, uint8_t opcode) { void WebSocketClient::sendData(String str, uint8_t opcode) { #ifdef DEBUGGING - Serial.print(F("Sending data: ")); - Serial.println(str); + if((char)str[0]!=0) { + Serial.print(F("Sending data: ")); + Serial.println(str); + } #endif if (socket_client->connected()) { sendEncodedData(str, opcode); @@ -276,39 +422,108 @@ int WebSocketClient::timedRead() { while (!socket_client->available()) { delay(20); } - +#ifdef DEBUGGING + char c = socket_client->read(); + Serial.println(c); + return c; +#else return socket_client->read(); +#endif +} + +#ifdef WS_BUFFERED_SEND +int WebSocketClient::bufferedSend(uint8_t c) { + if(bufferIndex0) { + if(socket_client->write(buffer, bufferIndex)) { + bufferIndex = 0; + return 1; + } else { + //Error sending. Most probable thing is socket disconnection + //Serial.println("######################################################################SEND FAILED"); + return 0; + } + } +} +#endif + void WebSocketClient::sendEncodedData(char *str, uint8_t opcode) { - uint8_t mask[4]; + uint8_t header[8]; int size = strlen(str); + int i = 0; // Opcode; final fragment - socket_client->write(opcode | WS_FIN); + header[i++] = opcode | WS_FIN; // NOTE: no support for > 16-bit sized messages if (size > 125) { - socket_client->write(WS_SIZE16 | WS_MASK); - socket_client->write((uint8_t) (size >> 8)); - socket_client->write((uint8_t) (size & 0xFF)); + header[i++] = WS_SIZE16 | WS_MASK; + header[i++] = (uint8_t) (size >> 8); + header[i++] = (uint8_t) (size & 0xFF); } else { - socket_client->write((uint8_t) size | WS_MASK); + header[i++] = (uint8_t) size | WS_MASK; } +#ifdef DEBUGGING + Serial.print("Sending message. Header: "); + int j; + for(j=0; jwrite(mask[0]); - socket_client->write(mask[1]); - socket_client->write(mask[2]); - socket_client->write(mask[3]); - - for (int i=0; iwrite(str[i] ^ mask[i % 4]); +#ifdef WS_BUFFERED_SEND +#ifdef DEBUGGING + Serial.print("Sending: "); +#endif + for(int k=0; kwrite(header, i)) { + for (int j=0; jwrite(c); + } } +#ifdef DEBUGGING + Serial.println(); +#endif +#endif /*WS_BUFFERED_SEND*/ } void WebSocketClient::sendEncodedData(String str, uint8_t opcode) { diff --git a/WebSocketClient.h b/WebSocketClient.h index 89b7c23..d559c02 100644 --- a/WebSocketClient.h +++ b/WebSocketClient.h @@ -45,9 +45,16 @@ Currently based off of "The Web Socket protocol" draft (v 75): #include #include -#include "String.h" +#include "string.h" #include "Client.h" +//Uncoment the following line for debug output ot serial port +//#define DEBUGGING + +#if defined ESP8266 || defined ESP32 || defined ARDUINO_SAMD_MKR1000 +#define WS_BUFFERED_SEND +#endif + // CRLF characters to terminate lines/handshakes in headers. #define CRLF "\r\n" @@ -64,9 +71,15 @@ Currently based off of "The Web Socket protocol" draft (v 75): // Don't allow the client to send big frames of data. This will flood the Arduinos // memory and might even crash it. #ifndef MAX_FRAME_LENGTH + +#if defined ESP8266 || defined ESP32 || defined ARDUINO_SAMD_MKR1000 +#define MAX_FRAME_LENGTH 2048 +#else #define MAX_FRAME_LENGTH 256 #endif +#endif + #define SIZE(array) (sizeof(array) / sizeof(*array)) // WebSocket protocol constants @@ -85,34 +98,55 @@ Currently based off of "The Web Socket protocol" draft (v 75): class WebSocketClient { public: - + WebSocketClient(char *WsPath = NULL, char *WsHost = NULL, char *WsHeaders = NULL, char *WsProtocol = NULL); + void setPath(char* WsPath); + void setHeaders(char *WsHeaders); + void setHost(char * WsHost); + void setProtocol(char * WsProtocol); // Handle connection requests to validate and process/refuse // connections. - bool handshake(Client &client); - + int handshake(Client &client); + //Check if socket os connected + int connected(); // Get data off of the stream bool getData(String& data, uint8_t *opcode = NULL); + bool getData(char *data, unsigned int dataLen, uint8_t *opcode = NULL); // Write data to the stream void sendData(const char *str, uint8_t opcode = WS_OPCODE_TEXT); void sendData(String str, uint8_t opcode = WS_OPCODE_TEXT); - char *path; - char *host; - char *protocol; +#ifdef WS_BUFFERED_SEND + int process(void); +#endif + + void disconnect(void) {disconnectStream();}; + private: Client *socket_client; - unsigned long _startMillis; - + const char *socket_urlPrefix; + char *path; + char *host; + char *protocol; + char *headers; + +#ifdef WS_BUFFERED_SEND + uint8_t buffer[MAX_FRAME_LENGTH]; + unsigned int bufferIndex; + int bufferedSend(uint8_t c); +#endif // Discovers if the client's header is requesting an upgrade to a // websocket connection. - bool analyzeRequest(); + int analyzeRequest(); + + bool handleMessageHeader(uint8_t *msgtype, unsigned int *length, bool *hasMask, uint8_t *mask, uint8_t *opcode); bool handleStream(String& data, uint8_t *opcode); - + bool handleStream(char *data, unsigned int dataLen, uint8_t *opcode); + // Disconnect user gracefully. void disconnectStream(); @@ -124,4 +158,4 @@ class WebSocketClient { -#endif \ No newline at end of file +#endif diff --git a/WebSocketServer.cpp b/WebSocketServer.cpp index bee55c5..cf2677c 100644 --- a/WebSocketServer.cpp +++ b/WebSocketServer.cpp @@ -5,11 +5,15 @@ #include "WebSocketServer.h" #ifdef SUPPORT_HIXIE_76 +#ifndef ESP32 #include "MD5.c" +#else +#include +#endif #endif #include "sha1.h" -#include "base64.h" +#include "Base64.h" bool WebSocketServer::handshake(Client &client) { diff --git a/WebSocketServer.h b/WebSocketServer.h index 6dcf052..078ab19 100644 --- a/WebSocketServer.h +++ b/WebSocketServer.h @@ -45,7 +45,7 @@ Currently based off of "The Web Socket protocol" draft (v 75): #include #include -#include "String.h" +#include "string.h" #include "Server.h" #include "Client.h" @@ -115,4 +115,4 @@ class WebSocketServer { -#endif \ No newline at end of file +#endif diff --git a/sha1.cpp b/sha1.cpp index 770f6f5..aa404cb 100755 --- a/sha1.cpp +++ b/sha1.cpp @@ -1,6 +1,8 @@ #include +#ifdef __AVR__ #include #include +#endif #include "sha1.h" #define SHA1_K0 0x5a827999 @@ -8,7 +10,7 @@ #define SHA1_K40 0x8f1bbcdc #define SHA1_K60 0xca62c1d6 -uint8_t sha1InitState[] PROGMEM = { +const uint8_t sha1InitState[] PROGMEM = { 0x01,0x23,0x45,0x67, // H0 0x89,0xab,0xcd,0xef, // H1 0xfe,0xdc,0xba,0x98, // H2