@@ -40,6 +40,9 @@ WebSocketsServer::WebSocketsServer(uint16_t port, String origin, String protocol
4040
4141 _cbEvent = NULL ;
4242
43+ _httpHeaderValidationFunc = NULL ;
44+ _mandatoryHttpHeaders = NULL ;
45+ _mandatoryHttpHeaderCount = 0 ;
4346}
4447
4548
@@ -53,10 +56,14 @@ WebSocketsServer::~WebSocketsServer() {
5356 // TODO how to close server?
5457#endif
5558
59+ if (_mandatoryHttpHeaders)
60+ delete[] _mandatoryHttpHeaders;
61+
62+ _mandatoryHttpHeaderCount = 0 ;
5663}
5764
5865/* *
59- * calles to init the Websockets server
66+ * called to initialize the Websocket server
6067 */
6168void WebSocketsServer::begin (void ) {
6269 WSclient_t * client;
@@ -83,6 +90,7 @@ void WebSocketsServer::begin(void) {
8390 client->base64Authorization = " " ;
8491
8592 client->cWsRXsize = 0 ;
93+
8694#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
8795 client->cHttpLine = " " ;
8896#endif
@@ -118,7 +126,30 @@ void WebSocketsServer::onEvent(WebSocketServerEvent cbEvent) {
118126 _cbEvent = cbEvent;
119127}
120128
121- /* *
129+ /*
130+ * Sets the custom http header validator function
131+ * If this functionality is being used, call this function prior to calling WebSocketsServer::begin
132+ * @param httpHeaderValidationFunc WebSocketServerHttpHeaderValFunc ///< pointer to the custom http header validation function
133+ * @param mandatoryHttpHeaders const char* ///< the array of named http headers considered to be mandatory / must be present in order for websocket upgrade to succeed
134+ */
135+ void WebSocketsServer::onValidateHttpHeader (
136+ WebSocketServerHttpHeaderValFunc validationFunc,
137+ const char * mandatoryHttpHeaders[])
138+ {
139+ _httpHeaderValidationFunc = validationFunc;
140+
141+ if (_mandatoryHttpHeaders)
142+ delete[] _mandatoryHttpHeaders;
143+
144+ _mandatoryHttpHeaderCount = (sizeof (mandatoryHttpHeaders) / sizeof (char *));
145+ _mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount];
146+
147+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
148+ _mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i];
149+ }
150+ }
151+
152+ /*
122153 * send text data to client
123154 * @param num uint8_t client id
124155 * @param payload uint8_t *
@@ -279,9 +310,8 @@ void WebSocketsServer::disconnect(uint8_t num) {
279310}
280311
281312
282-
283- /* *
284- * set the Authorizatio for the http request
313+ /*
314+ * set the Authorization for the http request
285315 * @param user const char *
286316 * @param password const char *
287317 */
@@ -388,7 +418,7 @@ bool WebSocketsServer::newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient) {
388418 * @param payload uint8_t *
389419 * @param lenght size_t
390420 */
391- void WebSocketsServer::messageRecived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
421+ void WebSocketsServer::messageReceived (WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) {
392422 WStype_t type = WStype_ERROR;
393423
394424 switch (opcode) {
@@ -446,6 +476,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
446476 client->cIsWebsocket = false ;
447477
448478 client->cWsRXsize = 0 ;
479+
449480#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC)
450481 client->cHttpLine = " " ;
451482#endif
@@ -461,7 +492,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) {
461492/* *
462493 * get client state
463494 * @param client WSclient_t * ptr to the client struct
464- * @return true = conneted
495+ * @return true = connected
465496 */
466497bool WebSocketsServer::clientIsConnected (WSclient_t * client) {
467498
@@ -492,7 +523,7 @@ bool WebSocketsServer::clientIsConnected(WSclient_t * client) {
492523}
493524#if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC)
494525/* *
495- * Handle incomming Connection Request
526+ * Handle incoming Connection Request
496527 */
497528void WebSocketsServer::handleNewClients (void ) {
498529
@@ -569,10 +600,22 @@ void WebSocketsServer::handleClientData(void) {
569600}
570601#endif
571602
603+ /*
604+ * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection
605+ * @param headerName String ///< the name of the header being checked
606+ */
607+ bool WebSocketsServer::hasMandatoryHeader (String headerName) {
608+ for (size_t i = 0 ; i < _mandatoryHttpHeaderCount; i++) {
609+ if (_mandatoryHttpHeaders[i].equalsIgnoreCase (headerName))
610+ return true ;
611+ }
612+ return false ;
613+ }
572614
573615/* *
574- * handle the WebSocket header reading
575- * @param client WSclient_t * ptr to the client struct
616+ * handles http header reading for WebSocket upgrade
617+ * @param client WSclient_t * ///< pointer to the client struct
618+ * @param headerLine String ///< the header being read / processed
576619 */
577620void WebSocketsServer::handleHeader (WSclient_t * client, String * headerLine) {
578621
@@ -581,10 +624,16 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
581624 if (headerLine->length () > 0 ) {
582625 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] RX: %s\n " , client->num , headerLine->c_str ());
583626
584- // websocket request starts allways with GET see rfc6455
627+ // websocket requests always start with GET see rfc6455
585628 if (headerLine->startsWith (" GET " )) {
629+
586630 // cut URL out
587631 client->cUrl = headerLine->substring (4 , headerLine->indexOf (' ' , 4 ));
632+
633+ // reset non-websocket http header validation state for this client
634+ client->cHttpHeadersValid = true ;
635+ client->cMandatoryHeadersCount = 0 ;
636+
588637 } else if (headerLine->indexOf (' :' )) {
589638 String headerName = headerLine->substring (0 , headerLine->indexOf (' :' ));
590639 String headerValue = headerLine->substring (headerLine->indexOf (' :' ) + 2 );
@@ -609,7 +658,13 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
609658 client->cExtensions = headerValue;
610659 } else if (headerName.equalsIgnoreCase (" Authorization" )) {
611660 client->base64Authorization = headerValue;
661+ } else {
662+ client->cHttpHeadersValid &= execHttpHeaderValidation (headerName, headerValue);
663+ if (_mandatoryHttpHeaderCount > 0 && hasMandatoryHeader (headerName)) {
664+ client->cMandatoryHeadersCount ++;
665+ }
612666 }
667+
613668 } else {
614669 DEBUG_WEBSOCKETS (" [WS-Client][handleHeader] Header error (%s)\n " , headerLine->c_str ());
615670 }
@@ -619,8 +674,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
619674 client->tcp ->readStringUntil (' \n ' , &(client->cHttpLine ), std::bind (&WebSocketsServer::handleHeader, this , client, &(client->cHttpLine )));
620675#endif
621676 } else {
622- DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
623677
678+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] Header read fin.\n " , client->num );
624679 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cURL: %s\n " , client->num , client->cUrl .c_str ());
625680 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsUpgrade: %d\n " , client->num , client->cIsUpgrade );
626681 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cIsWebsocket: %d\n " , client->num , client->cIsWebsocket );
@@ -629,6 +684,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
629684 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cExtensions: %s\n " , client->num , client->cExtensions .c_str ());
630685 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cVersion: %d\n " , client->num , client->cVersion );
631686 DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - base64Authorization: %s\n " , client->num , client->base64Authorization .c_str ());
687+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cHttpHeadersValid: %d\n " , client->num , client->cHttpHeadersValid );
688+ DEBUG_WEBSOCKETS (" [WS-Server][%d][handleHeader] - cMandatoryHeadersCount: %d\n " , client->num , client->cMandatoryHeadersCount );
632689
633690 bool ok = (client->cIsUpgrade && client->cIsWebsocket );
634691
@@ -642,6 +699,12 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) {
642699 if (client->cVersion != 13 ) {
643700 ok = false ;
644701 }
702+ if (!client->cHttpHeadersValid ) {
703+ ok = false ;
704+ }
705+ if (client->cMandatoryHeadersCount != _mandatoryHttpHeaderCount) {
706+ ok = false ;
707+ }
645708 }
646709
647710 if (_base64Authorization.length () > 0 ) {
0 commit comments