diff --git a/components/tcp_transport/CMakeLists.txt b/components/tcp_transport/CMakeLists.txt index 97f21bc9..db4afb87 100644 --- a/components/tcp_transport/CMakeLists.txt +++ b/components/tcp_transport/CMakeLists.txt @@ -1,11 +1,8 @@ -set(COMPONENT_SRCS "transport.c" - "transport_ssl.c" - "transport_tcp.c" - "transport_ws.c" - "transport_utils.c") - -set(COMPONENT_ADD_INCLUDEDIRS "include") - -set(COMPONENT_REQUIRES "lwip" "esp-tls") - -register_component() +idf_component_register(SRCS "transport.c" + "transport_ssl.c" + "transport_tcp.c" + "transport_ws.c" + "transport_utils.c" + INCLUDE_DIRS "include" + PRIV_INCLUDE_DIRS "private_include" + REQUIRES lwip esp-tls) diff --git a/components/tcp_transport/component.mk b/components/tcp_transport/component.mk index d492b815..c81b5477 100644 --- a/components/tcp_transport/component.mk +++ b/components/tcp_transport/component.mk @@ -1,7 +1,2 @@ -# -# Component Makefile -# -# (Uses default behaviour of compiling all source files in directory, adding 'include' to include path.) - -COMPONENT_SRCDIRS := . COMPONENT_ADD_INCLUDEDIRS := include +COMPONENT_PRIV_INCLUDEDIRS := private_include \ No newline at end of file diff --git a/components/tcp_transport/include/esp_transport.h b/components/tcp_transport/include/esp_transport.h index e163a010..39e694f0 100644 --- a/components/tcp_transport/include/esp_transport.h +++ b/components/tcp_transport/include/esp_transport.h @@ -22,7 +22,7 @@ extern "C" { #endif -typedef struct esp_transport_list_t* esp_transport_list_handle_t; +typedef struct esp_transport_internal* esp_transport_list_handle_t; typedef struct esp_transport_item_t* esp_transport_handle_t; typedef int (*connect_func)(esp_transport_handle_t t, const char *host, int port, int timeout_ms); @@ -33,12 +33,14 @@ typedef int (*poll_func)(esp_transport_handle_t t, int timeout_ms); typedef int (*connect_async_func)(esp_transport_handle_t t, const char *host, int port, int timeout_ms); typedef esp_transport_handle_t (*payload_transfer_func)(esp_transport_handle_t); +typedef struct esp_tls_last_error* esp_tls_error_handle_t; + /** * @brief Create transport list * * @return A handle can hold all transports */ -esp_transport_list_handle_t esp_transport_list_init(); +esp_transport_list_handle_t esp_transport_list_init(void); /** * @brief Cleanup and free all transports, include itself, @@ -91,7 +93,7 @@ esp_transport_handle_t esp_transport_list_get_transport(esp_transport_list_handl * * @return The transport handle */ -esp_transport_handle_t esp_transport_init(); +esp_transport_handle_t esp_transport_init(void); /** * @brief Cleanup and free memory the transport @@ -298,6 +300,21 @@ esp_err_t esp_transport_set_async_connect_func(esp_transport_handle_t t, connect */ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payload_transfer_func _parent_transport); +/** + * @brief Returns esp_tls error handle. + * Warning: The returned pointer is valid only as long as esp_transport_handle_t exists. Once transport + * handle gets destroyed, this value (esp_tls_error_handle_t) is freed automatically. + * + * @param[in] A transport handle + * + * @return + * - valid pointer of esp_error_handle_t + * - NULL if invalid transport handle + */ +esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t); + + + #ifdef __cplusplus } #endif diff --git a/components/tcp_transport/include/esp_transport_ssl.h b/components/tcp_transport/include/esp_transport_ssl.h index 34045ae7..a83e9388 100644 --- a/components/tcp_transport/include/esp_transport_ssl.h +++ b/components/tcp_transport/include/esp_transport_ssl.h @@ -16,6 +16,7 @@ #define _ESP_TRANSPORT_SSL_H_ #include "esp_transport.h" +#include "esp_tls.h" #ifdef __cplusplus extern "C" { @@ -27,7 +28,7 @@ extern "C" { * * @return the allocated esp_transport_handle_t, or NULL if the handle can not be allocated */ -esp_transport_handle_t esp_transport_ssl_init(); +esp_transport_handle_t esp_transport_ssl_init(void); /** * @brief Set SSL certificate data (as PEM format). @@ -40,6 +41,24 @@ esp_transport_handle_t esp_transport_ssl_init(); */ void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len); +/** + * @brief Set SSL certificate data (as DER format). + * Note that, this function stores the pointer to data, rather than making a copy. + * So this data must remain valid until after the connection is cleaned up + * + * @param t ssl transport + * @param[in] data The der data + * @param[in] len The length + */ +void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len); + +/** + * @brief Enable global CA store for SSL connection + * + * @param t ssl transport + */ +void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t); + /** * @brief Set SSL client certificate data for mutual authentication (as PEM format). * Note that, this function stores the pointer to data, rather than making a copy. @@ -51,6 +70,17 @@ void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, */ void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len); +/** + * @brief Set SSL client certificate data for mutual authentication (as DER format). + * Note that, this function stores the pointer to data, rather than making a copy. + * So this data must remain valid until after the connection is cleaned up + * + * @param t ssl transport + * @param[in] data The der data + * @param[in] len The length + */ +void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len); + /** * @brief Set SSL client key data for mutual authentication (as PEM format). * Note that, this function stores the pointer to data, rather than making a copy. @@ -62,6 +92,40 @@ void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char */ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len); +/** + * @brief Set SSL client key data for mutual authentication (as DER format). + * Note that, this function stores the pointer to data, rather than making a copy. + * So this data must remain valid until after the connection is cleaned up + * + * @param t ssl transport + * @param[in] data The der data + * @param[in] len The length + */ +void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len); + +/** + * @brief Skip validation of certificate's common name field + * + * @note Skipping CN validation is not recommended + * + * @param t ssl transport + */ +void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t); + +/** + * @brief Set PSK key and hint for PSK server/client verification in esp-tls component. + * Important notes: + * - This function stores the pointer to data, rather than making a copy. + * So this data must remain valid until after the connection is cleaned up + * - ESP_TLS_PSK_VERIFICATION config option must be enabled in menuconfig + * - certificate verification takes priority so it must not be configured + * to enable PSK method. + * + * @param t ssl transport + * @param[in] psk_hint_key psk key and hint structure defined in esp_tls.h + */ +void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key); + #ifdef __cplusplus } #endif diff --git a/components/tcp_transport/include/esp_transport_tcp.h b/components/tcp_transport/include/esp_transport_tcp.h index 57ad4533..7a283fe9 100644 --- a/components/tcp_transport/include/esp_transport_tcp.h +++ b/components/tcp_transport/include/esp_transport_tcp.h @@ -26,7 +26,7 @@ extern "C" { * * @return the allocated esp_transport_handle_t, or NULL if the handle can not be allocated */ -esp_transport_handle_t esp_transport_tcp_init(); +esp_transport_handle_t esp_transport_tcp_init(void); #ifdef __cplusplus diff --git a/components/tcp_transport/include/esp_transport_utils.h b/components/tcp_transport/include/esp_transport_utils.h index 405b4f6b..6a9d1d02 100644 --- a/components/tcp_transport/include/esp_transport_utils.h +++ b/components/tcp_transport/include/esp_transport_utils.h @@ -20,6 +20,15 @@ extern "C" { #endif +/** + * @brief Utility macro to be used for NULL ptr check after malloc + * + */ +#define ESP_TRANSPORT_MEM_CHECK(TAG, a, action) if (!(a)) { \ + ESP_LOGE(TAG,"%s:%d (%s): %s", __FILE__, __LINE__, __FUNCTION__, "Memory exhausted"); \ + action; \ + } + /** * @brief Convert milliseconds to timeval struct * @@ -29,11 +38,6 @@ extern "C" { void esp_transport_utils_ms_to_timeval(int timeout_ms, struct timeval *tv); -#define ESP_TRANSPORT_MEM_CHECK(TAG, a, action) if (!(a)) { \ - ESP_LOGE(TAG,"%s:%d (%s): %s", __FILE__, __LINE__, __FUNCTION__, "Memory exhausted"); \ - action; \ - } - #ifdef __cplusplus } #endif diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 582c5c7d..0876480a 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -13,6 +13,13 @@ extern "C" { #endif +typedef enum ws_transport_opcodes { + WS_TRANSPORT_OPCODES_TEXT = 0x01, + WS_TRANSPORT_OPCODES_BINARY = 0x02, + WS_TRANSPORT_OPCODES_CLOSE = 0x08, + WS_TRANSPORT_OPCODES_PING = 0x09, + WS_TRANSPORT_OPCODES_PONG = 0x0a, +} ws_transport_opcodes_t; /** * @brief Create web socket transport @@ -23,8 +30,56 @@ extern "C" { */ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle); +/** + * @brief Set HTTP path to update protocol to websocket + * + * @param t websocket transport handle + * @param path The HTTP Path + */ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path); +/** + * @brief Set websocket sub protocol header + * + * @param t websocket transport handle + * @param sub_protocol Sub protocol string + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol); + +/** + * @brief Sends websocket raw message with custom opcode and payload + * + * Note that generic esp_transport_write for ws handle sends + * binary massages by default if size is > 0 and + * ping message if message size is set to 0. + * This API is provided to support explicit messages with arbitrary opcode, + * should it be PING, PONG or TEXT message with arbitrary data. + * + * @param[in] t Websocket transport handle + * @param[in] opcode ws operation code + * @param[in] buffer The buffer + * @param[in] len The length + * @param[in] timeout_ms The timeout milliseconds + * + * @return + * - Number of bytes was written + * - (-1) if there are any errors, should check errno + */ +int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t opcode, const char *b, int len, int timeout_ms); + +/** + * @brief Returns websocket op-code for last received data + * + * @param t websocket transport handle + * + * @return + * - Received op-code as enum + */ +ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t); #ifdef __cplusplus diff --git a/components/tcp_transport/private_include/esp_transport_ssl_internal.h b/components/tcp_transport/private_include/esp_transport_ssl_internal.h new file mode 100644 index 00000000..8b794e05 --- /dev/null +++ b/components/tcp_transport/private_include/esp_transport_ssl_internal.h @@ -0,0 +1,30 @@ +// Copyright 2015-2019 Espressif Systems (Shanghai) PTE LTD +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _ESP_TRANSPORT_INTERNAL_H_ +#define _ESP_TRANSPORT_INTERNAL_H_ + +/** + * @brief Sets error to common transport handle + * + * Note: This function copies the supplied error handle object to tcp_transport's internal + * error handle object + * + * @param[in] A transport handle + * + */ +void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle); + + +#endif /* _ESP_TRANSPORT_INTERNAL_H_ */ diff --git a/components/tcp_transport/test/CMakeLists.txt b/components/tcp_transport/test/CMakeLists.txt new file mode 100644 index 00000000..88504a27 --- /dev/null +++ b/components/tcp_transport/test/CMakeLists.txt @@ -0,0 +1,5 @@ +set(COMPONENT_SRCDIRS ".") +set(COMPONENT_PRIV_INCLUDEDIRS "../private_include" ".") +set(COMPONENT_PRIV_REQUIRES unity test_utils tcp_transport) + +register_component() \ No newline at end of file diff --git a/components/tcp_transport/test/component.mk b/components/tcp_transport/test/component.mk new file mode 100644 index 00000000..22e49edd --- /dev/null +++ b/components/tcp_transport/test/component.mk @@ -0,0 +1,5 @@ +# +#Component Makefile +# +COMPONENT_PRIV_INCLUDEDIRS := ../private_include . +COMPONENT_ADD_LDFLAGS = -Wl,--whole-archive -l$(COMPONENT_NAME) -Wl,--no-whole-archive \ No newline at end of file diff --git a/components/tcp_transport/test/test_transport.c b/components/tcp_transport/test/test_transport.c new file mode 100644 index 00000000..28702fb0 --- /dev/null +++ b/components/tcp_transport/test/test_transport.c @@ -0,0 +1,28 @@ +#include "unity.h" + +#include "esp_transport.h" +#include "esp_transport_tcp.h" +#include "esp_transport_ssl.h" +#include "esp_transport_ws.h" + +TEST_CASE("tcp_transport: init and deinit transport list", "[tcp_transport][leaks=0]") +{ + esp_transport_list_handle_t transport_list = esp_transport_list_init(); + esp_transport_handle_t tcp = esp_transport_tcp_init(); + esp_transport_list_add(transport_list, tcp, "tcp"); + TEST_ASSERT_EQUAL(ESP_OK, esp_transport_list_destroy(transport_list)); +} + +TEST_CASE("tcp_transport: using ssl transport separately", "[tcp_transport][leaks=0]") +{ + esp_transport_handle_t h = esp_transport_ssl_init(); + TEST_ASSERT_EQUAL(ESP_OK, esp_transport_destroy(h)); +} + +TEST_CASE("tcp_transport: using ws transport separately", "[tcp_transport][leaks=0]") +{ + esp_transport_handle_t tcp = esp_transport_tcp_init(); + esp_transport_handle_t ws = esp_transport_ws_init(tcp); + TEST_ASSERT_EQUAL(ESP_OK, esp_transport_destroy(ws)); + TEST_ASSERT_EQUAL(ESP_OK, esp_transport_destroy(tcp)); +} diff --git a/components/tcp_transport/transport.c b/components/tcp_transport/transport.c index f6fa311a..d5bc57bb 100644 --- a/components/tcp_transport/transport.c +++ b/components/tcp_transport/transport.c @@ -15,8 +15,9 @@ #include #include +#include -#include "rom/queue.h" +#include "sys/queue.h" #include "esp_log.h" #include "esp_transport.h" @@ -41,7 +42,8 @@ struct esp_transport_item_t { poll_func _poll_write; /*!< Poll and write */ trans_func _destroy; /*!< Destroy and free transport */ connect_async_func _connect_async; /*!< non-blocking connect function of this transport */ - payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */ + payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */ + esp_tls_error_handle_t error_handle; /*!< Pointer to esp-tls error handle */ STAILQ_ENTRY(esp_transport_item_t) next; }; @@ -52,6 +54,14 @@ struct esp_transport_item_t { */ STAILQ_HEAD(esp_transport_list_t, esp_transport_item_t); +/** + * Internal transport structure holding list of transports and other data common to all transports + */ +typedef struct esp_transport_internal { + struct esp_transport_list_t list; /*!< List of transports */ + esp_tls_error_handle_t error_handle; /*!< Pointer to the error tracker if enabled */ +} esp_transport_internal_t; + static esp_transport_handle_t esp_transport_get_default_parent(esp_transport_handle_t t) { /* @@ -60,36 +70,39 @@ static esp_transport_handle_t esp_transport_get_default_parent(esp_transport_han return t; } -esp_transport_list_handle_t esp_transport_list_init() +esp_transport_list_handle_t esp_transport_list_init(void) { - esp_transport_list_handle_t list = calloc(1, sizeof(struct esp_transport_list_t)); - ESP_TRANSPORT_MEM_CHECK(TAG, list, return NULL); - STAILQ_INIT(list); - return list; + esp_transport_list_handle_t transport = calloc(1, sizeof(esp_transport_internal_t)); + ESP_TRANSPORT_MEM_CHECK(TAG, transport, return NULL); + STAILQ_INIT(&transport->list); + transport->error_handle = calloc(1, sizeof(esp_tls_last_error_t)); + return transport; } -esp_err_t esp_transport_list_add(esp_transport_list_handle_t list, esp_transport_handle_t t, const char *scheme) +esp_err_t esp_transport_list_add(esp_transport_list_handle_t h, esp_transport_handle_t t, const char *scheme) { - if (list == NULL || t == NULL) { + if (h == NULL || t == NULL) { return ESP_ERR_INVALID_ARG; } t->scheme = calloc(1, strlen(scheme) + 1); ESP_TRANSPORT_MEM_CHECK(TAG, t->scheme, return ESP_ERR_NO_MEM); strcpy(t->scheme, scheme); - STAILQ_INSERT_TAIL(list, t, next); + STAILQ_INSERT_TAIL(&h->list, t, next); + // Each transport in a list to share the same error tracker + t->error_handle = h->error_handle; return ESP_OK; } -esp_transport_handle_t esp_transport_list_get_transport(esp_transport_list_handle_t list, const char *scheme) +esp_transport_handle_t esp_transport_list_get_transport(esp_transport_list_handle_t h, const char *scheme) { - if (!list) { + if (!h) { return NULL; } if (scheme == NULL) { - return STAILQ_FIRST(list); + return STAILQ_FIRST(&h->list); } esp_transport_handle_t item; - STAILQ_FOREACH(item, list, next) { + STAILQ_FOREACH(item, &h->list, next) { if (strcasecmp(item->scheme, scheme) == 0) { return item; } @@ -97,30 +110,28 @@ esp_transport_handle_t esp_transport_list_get_transport(esp_transport_list_handl return NULL; } -esp_err_t esp_transport_list_destroy(esp_transport_list_handle_t list) +esp_err_t esp_transport_list_destroy(esp_transport_list_handle_t h) { - esp_transport_list_clean(list); - free(list); + esp_transport_list_clean(h); + free(h->error_handle); + free(h); return ESP_OK; } -esp_err_t esp_transport_list_clean(esp_transport_list_handle_t list) +esp_err_t esp_transport_list_clean(esp_transport_list_handle_t h) { - esp_transport_handle_t item = STAILQ_FIRST(list); + esp_transport_handle_t item = STAILQ_FIRST(&h->list); esp_transport_handle_t tmp; while (item != NULL) { tmp = STAILQ_NEXT(item, next); - if (item->_destroy) { - item->_destroy(item); - } esp_transport_destroy(item); item = tmp; } - STAILQ_INIT(list); + STAILQ_INIT(&h->list); return ESP_OK; } -esp_transport_handle_t esp_transport_init() +esp_transport_handle_t esp_transport_init(void) { esp_transport_handle_t t = calloc(1, sizeof(struct esp_transport_item_t)); ESP_TRANSPORT_MEM_CHECK(TAG, t, return NULL); @@ -137,6 +148,9 @@ esp_transport_handle_t esp_transport_get_payload_transport_handle(esp_transport_ esp_err_t esp_transport_destroy(esp_transport_handle_t t) { + if (t->_destroy) { + t->_destroy(t); + } if (t->scheme) { free(t->scheme); } @@ -277,3 +291,18 @@ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payl t->_parent_transfer = _parent_transport; return ESP_OK; } + +esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t) +{ + if (t) { + return t->error_handle; + } + return NULL; +} + +void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle) +{ + if (t) { + memcpy(t->error_handle, error_handle, sizeof(esp_tls_last_error_t)); + } +} \ No newline at end of file diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index 17a8426e..b92c2115 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -24,6 +24,7 @@ #include "esp_transport.h" #include "esp_transport_ssl.h" #include "esp_transport_utils.h" +#include "esp_transport_ssl_internal.h" static const char *TAG = "TRANS_SSL"; @@ -51,7 +52,7 @@ static int ssl_connect_async(esp_transport_handle_t t, const char *host, int por ssl->cfg.timeout_ms = timeout_ms; ssl->cfg.non_block = true; ssl->ssl_initialized = true; - ssl->tls = calloc(1, sizeof(esp_tls_t)); + ssl->tls = esp_tls_init(); if (!ssl->tls) { return -1; } @@ -69,35 +70,63 @@ static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int ssl->cfg.timeout_ms = timeout_ms; ssl->ssl_initialized = true; - ssl->tls = esp_tls_conn_new(host, strlen(host), port, &ssl->cfg); - if (!ssl->tls) { + ssl->tls = esp_tls_init(); + if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) < 0) { ESP_LOGE(TAG, "Failed to open a new connection"); + esp_transport_set_errors(t, ssl->tls->error_handle); + esp_tls_conn_delete(ssl->tls); + ssl->tls = NULL; return -1; } + return 0; } static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_ssl_t *ssl = esp_transport_get_context_data(t); + int ret = -1; fd_set readset; + fd_set errset; FD_ZERO(&readset); + FD_ZERO(&errset); FD_SET(ssl->tls->sockfd, &readset); + FD_SET(ssl->tls->sockfd, &errset); struct timeval timeout; esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - return select(ssl->tls->sockfd + 1, &readset, NULL, NULL, &timeout); + ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, &timeout); + if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { + int sock_errno = 0; + uint32_t optlen = sizeof(sock_errno); + getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + ESP_LOGE(TAG, "ssl_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); + ret = -1; + } + return ret; } static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) { transport_ssl_t *ssl = esp_transport_get_context_data(t); + int ret = -1; fd_set writeset; + fd_set errset; FD_ZERO(&writeset); + FD_ZERO(&errset); FD_SET(ssl->tls->sockfd, &writeset); + FD_SET(ssl->tls->sockfd, &errset); struct timeval timeout; esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - return select(ssl->tls->sockfd + 1, NULL, &writeset, NULL, &timeout); + ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, &timeout); + if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { + int sock_errno = 0; + uint32_t optlen = sizeof(sock_errno); + getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + ESP_LOGE(TAG, "ssl_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); + ret = -1; + } + return ret; } static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) @@ -110,8 +139,9 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int return poll; } ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len); - if (ret <= 0) { + if (ret < 0) { ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno)); + esp_transport_set_errors(t, ssl->tls->error_handle); } return ret; } @@ -127,9 +157,12 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout } } ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len); - if (ret <= 0) { + if (ret < 0) { ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno)); - return -1; + esp_transport_set_errors(t, ssl->tls->error_handle); + } + if (ret == 0) { + ret = -1; } return ret; } @@ -153,6 +186,22 @@ static int ssl_destroy(esp_transport_handle_t t) return 0; } +void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.use_global_ca_store = true; + } +} + +void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.psk_hint_key = psk_hint_key; + } +} + void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len) { transport_ssl_t *ssl = esp_transport_get_context_data(t); @@ -162,6 +211,15 @@ void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, } } +void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.cacert_buf = (void *)data; + ssl->cfg.cacert_bytes = len; + } +} + void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len) { transport_ssl_t *ssl = esp_transport_get_context_data(t); @@ -171,6 +229,15 @@ void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char } } +void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.clientcert_buf = (void *)data; + ssl->cfg.clientcert_bytes = len; + } +} + void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len) { transport_ssl_t *ssl = esp_transport_get_context_data(t); @@ -180,7 +247,24 @@ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char } } -esp_transport_handle_t esp_transport_ssl_init() +void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.clientkey_buf = (void *)data; + ssl->cfg.clientkey_bytes = len; + } +} + +void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.skip_common_name = true; + } +} + +esp_transport_handle_t esp_transport_ssl_init(void) { esp_transport_handle_t t = esp_transport_init(); transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t)); diff --git a/components/tcp_transport/transport_tcp.c b/components/tcp_transport/transport_tcp.c index 34d85596..3fba399a 100644 --- a/components/tcp_transport/transport_tcp.c +++ b/components/tcp_transport/transport_tcp.c @@ -15,7 +15,7 @@ #include #include -#include +#include "lwip/sockets.h" #include "lwip/dns.h" #include "lwip/netdb.h" @@ -77,6 +77,7 @@ static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int esp_transport_utils_ms_to_timeval(timeout_ms, &tv); setsockopt(tcp->sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setsockopt(tcp->sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); ESP_LOGD(TAG, "[sock=%d],connecting to server IP:%s,Port:%d...", tcp->sock, ipaddr_ntoa((const ip_addr_t*)&remote_ip.sin_addr.s_addr), port); @@ -115,23 +116,47 @@ static int tcp_read(esp_transport_handle_t t, char *buffer, int len, int timeout static int tcp_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_tcp_t *tcp = esp_transport_get_context_data(t); + int ret = -1; fd_set readset; + fd_set errset; FD_ZERO(&readset); + FD_ZERO(&errset); FD_SET(tcp->sock, &readset); + FD_SET(tcp->sock, &errset); struct timeval timeout; esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - return select(tcp->sock + 1, &readset, NULL, NULL, &timeout); + ret = select(tcp->sock + 1, &readset, NULL, &errset, &timeout); + if (ret > 0 && FD_ISSET(tcp->sock, &errset)) { + int sock_errno = 0; + uint32_t optlen = sizeof(sock_errno); + getsockopt(tcp->sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + ESP_LOGE(TAG, "tcp_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), tcp->sock); + ret = -1; + } + return ret; } static int tcp_poll_write(esp_transport_handle_t t, int timeout_ms) { transport_tcp_t *tcp = esp_transport_get_context_data(t); + int ret = -1; fd_set writeset; + fd_set errset; FD_ZERO(&writeset); + FD_ZERO(&errset); FD_SET(tcp->sock, &writeset); + FD_SET(tcp->sock, &errset); struct timeval timeout; esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - return select(tcp->sock + 1, NULL, &writeset, NULL, &timeout); + ret = select(tcp->sock + 1, NULL, &writeset, &errset, &timeout); + if (ret > 0 && FD_ISSET(tcp->sock, &errset)) { + int sock_errno = 0; + uint32_t optlen = sizeof(sock_errno); + getsockopt(tcp->sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + ESP_LOGE(TAG, "tcp_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), tcp->sock); + ret = -1; + } + return ret; } static int tcp_close(esp_transport_handle_t t) @@ -153,7 +178,7 @@ static esp_err_t tcp_destroy(esp_transport_handle_t t) return 0; } -esp_transport_handle_t esp_transport_tcp_init() +esp_transport_handle_t esp_transport_tcp_init(void) { esp_transport_handle_t t = esp_transport_init(); transport_tcp_t *tcp = calloc(1, sizeof(transport_tcp_t)); diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 86b34689..637d4d61 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -1,10 +1,7 @@ -#include "sdkconfig.h" - -#ifdef CONFIG_SSL_USING_MBEDTLS #include #include #include -#include "esp_libc.h" +#include #include "esp_log.h" #include "esp_transport.h" @@ -27,15 +24,22 @@ static const char *TAG = "TRANSPORT_WS"; #define WS_MASK 0x80 #define WS_SIZE16 126 #define WS_SIZE64 127 -#define MAX_WEBSOCKET_HEADER_SIZE 10 +#define MAX_WEBSOCKET_HEADER_SIZE 16 #define WS_RESPONSE_OK 101 typedef struct { char *path; char *buffer; + char *sub_protocol; + uint8_t read_opcode; esp_transport_handle_t parent; } transport_ws_t; +static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode) +{ + return (uint8_t)opcode; +} + static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); @@ -63,10 +67,9 @@ static char *trimwhitespace(const char *str) return (char *)str; } - static char *get_http_header(const char *buffer, const char *key) { - char *found = strstr(buffer, key); + char *found = strcasestr(buffer, key); if (found) { found += strlen(key); char *found_end = strstr(found, "\r\n"); @@ -83,11 +86,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int { transport_ws_t *ws = esp_transport_get_context_data(t); if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) { - ESP_LOGE(TAG, "Error connect to ther server"); + ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port); + return -1; } unsigned char random_key[16]; - os_get_random(random_key, sizeof(random_key)); + getrandom(random_key, sizeof(random_key), 0); // Size of base64 coded string is equal '((input_size * 4) / 3) + (input_size / 96) + 6' including Z-term unsigned char client_key[28] = {0}; @@ -100,9 +104,8 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int "Host: %s:%d\r\n" "Upgrade: websocket\r\n" "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Protocol: mqtt\r\n" "Sec-WebSocket-Key: %s\r\n" - "User-Agent: ESP32 Websocket Client\r\n\r\n", + "User-Agent: ESP32 Websocket Client\r\n", ws->path, host, port, client_key); @@ -110,15 +113,38 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int ESP_LOGE(TAG, "Error in request generation, %d", len); return -1; } + if (ws->sub_protocol) { + int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "Sec-WebSocket-Protocol: %s\r\n", ws->sub_protocol); + len += r; + if (r <= 0 || len >= DEFAULT_WS_BUFFER) { + ESP_LOGE(TAG, "Error in request generation" + "(snprintf of subprotocol returned %d, desired request len: %d, buffer size: %d", r, len, DEFAULT_WS_BUFFER); + return -1; + } + } + int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "\r\n"); + len += r; + if (r <= 0 || len >= DEFAULT_WS_BUFFER) { + ESP_LOGE(TAG, "Error in request generation" + "(snprintf of header terminal returned %d, desired request len: %d, buffer size: %d", r, len, DEFAULT_WS_BUFFER); + return -1; + } ESP_LOGD(TAG, "Write upgrate request\r\n%s", ws->buffer); if (esp_transport_write(ws->parent, ws->buffer, len, timeout_ms) <= 0) { ESP_LOGE(TAG, "Error write Upgrade header %s", ws->buffer); return -1; } - if ((len = esp_transport_read(ws->parent, ws->buffer, DEFAULT_WS_BUFFER, timeout_ms)) <= 0) { - ESP_LOGE(TAG, "Error read response for Upgrade header %s", ws->buffer); - return -1; - } + int header_len = 0; + do { + if ((len = esp_transport_read(ws->parent, ws->buffer + header_len, DEFAULT_WS_BUFFER - header_len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read response for Upgrade header %s", ws->buffer); + return -1; + } + header_len += len; + ws->buffer[header_len] = '\0'; + ESP_LOGD(TAG, "Read header chunk %d, current header size: %d", len, header_len); + } while (NULL == strstr(ws->buffer, "\r\n\r\n") && header_len < DEFAULT_WS_BUFFER); + char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:"); if (server_key == NULL) { ESP_LOGE(TAG, "Sec-WebSocket-Accept not found"); @@ -147,70 +173,131 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return 0; } -static int ws_write(esp_transport_handle_t t, const char *buff, int len, int timeout_ms) +static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const char *b, int len, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); + char *buffer = (char *)b; char ws_header[MAX_WEBSOCKET_HEADER_SIZE]; char *mask; int header_len = 0, i; - char *buffer = (char *)buff; + int poll_write; if ((poll_write = esp_transport_poll_write(ws->parent, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error transport_poll_write"); return poll_write; } + ws_header[header_len++] = opcode; - ws_header[header_len++] = WS_OPCODE_BINARY | WS_FIN; - - // NOTE: no support for > 16-bit sized messages - if (len > 125) { - ws_header[header_len++] = WS_SIZE16 | WS_MASK; + if (len <= 125) { + ws_header[header_len++] = (uint8_t)(len | mask_flag); + } else if (len < 65536) { + ws_header[header_len++] = WS_SIZE16 | mask_flag; ws_header[header_len++] = (uint8_t)(len >> 8); ws_header[header_len++] = (uint8_t)(len & 0xFF); } else { - ws_header[header_len++] = (uint8_t)(len | WS_MASK); + ws_header[header_len++] = WS_SIZE64 | mask_flag; + /* Support maximum 4 bytes length */ + ws_header[header_len++] = 0; //(uint8_t)((len >> 56) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 48) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 40) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 32) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 24) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 16) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 8) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 0) & 0xFF); } - mask = &ws_header[header_len]; - os_get_random((unsigned char *)ws_header + header_len, 4); - header_len += 4; - for (i = 0; i < len; ++i) { - buffer[i] = (buffer[i] ^ mask[i % 4]); + if (mask_flag) { + mask = &ws_header[header_len]; + getrandom(ws_header + header_len, 4, 0); + header_len += 4; + + for (i = 0; i < len; ++i) { + buffer[i] = (buffer[i] ^ mask[i % 4]); + } } + if (esp_transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) { ESP_LOGE(TAG, "Error write header"); return -1; } - return esp_transport_write(ws->parent, buffer, len, timeout_ms); + if (len == 0) { + return 0; + } + + int ret = esp_transport_write(ws->parent, buffer, len, timeout_ms); + // in case of masked transport we have to revert back to the original data, as ws layer + // does not create its own copy of data to be sent + if (mask_flag) { + mask = &ws_header[header_len-4]; + for (i = 0; i < len; ++i) { + buffer[i] = (buffer[i] ^ mask[i % 4]); + } + } + return ret; +} + +int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t opcode, const char *b, int len, int timeout_ms) +{ + uint8_t op_code = ws_get_bin_opcode(opcode); + if (t == NULL) { + ESP_LOGE(TAG, "Transport must be a valid ws handle"); + return ESP_ERR_INVALID_ARG; + } + ESP_LOGD(TAG, "Sending raw ws message with opcode %d", op_code); + return _ws_write(t, op_code | WS_FIN, WS_MASK, b, len, timeout_ms); +} + +static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeout_ms) +{ + if (len == 0) { + // Default transport write of zero length in ws layer sends out a ping message. + // This behaviour could however be altered in IDF 5.0, since a separate API for sending + // messages with user defined opcodes has been introduced. + ESP_LOGD(TAG, "Write PING message"); + return _ws_write(t, WS_OPCODE_PING | WS_FIN, WS_MASK, NULL, 0, timeout_ms); + } + return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms); } static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); int payload_len; - int payload_len_buff = len; - char *data_ptr = buffer, opcode, mask, *mask_key = NULL; + char ws_header[MAX_WEBSOCKET_HEADER_SIZE]; + char *data_ptr = ws_header, mask, *mask_key = NULL; int rlen; int poll_read; if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) { return poll_read; } - if ((rlen = esp_transport_read(ws->parent, buffer, len, timeout_ms)) <= 0) { + + // Receive and process header first (based on header size) + int header = 2; + if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } - opcode = (*data_ptr & 0x0F); + ws->read_opcode = (*data_ptr & 0x0F); data_ptr ++; mask = ((*data_ptr >> 7) & 0x01); payload_len = (*data_ptr & 0x7F); data_ptr++; - ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", opcode, mask, payload_len); + ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->read_opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; + if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; + } payload_len = data_ptr[0] << 8 | data_ptr[1]; - payload_len_buff = len - 4; - data_ptr += 2; } else if (payload_len == 127) { // headerLen += 8; + header = 8; + if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; + } if (data_ptr[0] != 0 || data_ptr[1] != 0 || data_ptr[2] != 0 || data_ptr[3] != 0) { // really too big! @@ -218,22 +305,25 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ } else { payload_len = data_ptr[4] << 24 | data_ptr[5] << 16 | data_ptr[6] << 8 | data_ptr[7]; } - data_ptr += 8; - payload_len_buff = len - 10; } - if (payload_len > payload_len_buff) { - ESP_LOGD(TAG, "Actual data received (%d) are longer than mqtt buffer (%d)", payload_len, payload_len_buff); - payload_len = payload_len_buff; + + if (payload_len > len) { + ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", payload_len, len); + payload_len = len; + } + + // Then receive and process payload + if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; } if (mask) { - mask_key = data_ptr; - data_ptr += 4; + mask_key = buffer; + data_ptr = buffer + 4; for (int i = 0; i < payload_len; i++) { buffer[i] = (data_ptr[i] ^ mask_key[i % 4]); } - } else { - memmove(buffer, data_ptr, payload_len); } return payload_len; } @@ -261,6 +351,7 @@ static esp_err_t ws_destroy(esp_transport_handle_t t) transport_ws_t *ws = esp_transport_get_context_data(t); free(ws->buffer); free(ws->path); + free(ws->sub_protocol); free(ws); return 0; } @@ -270,6 +361,7 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path) ws->path = realloc(ws->path, strlen(path) + 1); strcpy(ws->path, path); } + esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle) { esp_transport_handle_t t = esp_transport_init(); @@ -278,7 +370,10 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl ws->parent = parent_handle; ws->path = strdup("/"); - ESP_TRANSPORT_MEM_CHECK(TAG, ws->path, return NULL); + ESP_TRANSPORT_MEM_CHECK(TAG, ws->path, { + free(ws); + return NULL; + }); ws->buffer = malloc(DEFAULT_WS_BUFFER); ESP_TRANSPORT_MEM_CHECK(TAG, ws->buffer, { free(ws->path); @@ -293,4 +388,29 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl esp_transport_set_context_data(t, ws); return t; } -#endif + +esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->sub_protocol) { + free(ws->sub_protocol); + } + if (sub_protocol == NULL) { + ws->sub_protocol = NULL; + return ESP_OK; + } + ws->sub_protocol = strdup(sub_protocol); + if (ws->sub_protocol == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + +ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + return ws->read_opcode; +}