From 0cd33eaea1fb5248013a95487faccb60358e9c09 Mon Sep 17 00:00:00 2001 From: yuanjm Date: Fri, 27 Mar 2020 14:23:14 +0800 Subject: [PATCH] feat(tls): update esp-tls and tcp_transport from idf Commit ID:88bf21b2 --- components/esp-tls/Kconfig | 18 +- components/esp-tls/esp_tls.c | 60 ++- components/esp-tls/esp_tls.h | 141 +++--- components/esp-tls/esp_tls_mbedtls.c | 31 +- components/esp-tls/esp_tls_wolfssl.c | 401 ++++++++++++++++-- .../esp-tls/private_include/esp_tls_wolfssl.h | 14 + .../esp_http_client/lib/include/http_utils.h | 8 +- .../tcp_transport/include/esp_transport.h | 12 +- .../tcp_transport/include/esp_transport_ssl.h | 20 + .../tcp_transport/include/esp_transport_ws.h | 37 +- .../esp_transport_utils.h | 13 +- components/tcp_transport/test/CMakeLists.txt | 8 +- components/tcp_transport/transport_ssl.c | 29 +- components/tcp_transport/transport_tcp.c | 16 +- components/tcp_transport/transport_utils.c | 6 +- components/tcp_transport/transport_ws.c | 187 ++++++-- 16 files changed, 810 insertions(+), 191 deletions(-) rename components/tcp_transport/{include => private_include}/esp_transport_utils.h (65%) diff --git a/components/esp-tls/Kconfig b/components/esp-tls/Kconfig index 0bfe9a6e..dc915417 100644 --- a/components/esp-tls/Kconfig +++ b/components/esp-tls/Kconfig @@ -14,22 +14,22 @@ menu "ESP-TLS" config ESP_TLS_SERVER bool "Enable ESP-TLS Server" - depends on ESP_TLS_USING_MBEDTLS default n help - Enable support for creating server side SSL/TLS session, uses the mbedtls crypto library + Enable support for creating server side SSL/TLS session, available for mbedTLS + as well as wolfSSL TLS library. config ESP_TLS_PSK_VERIFICATION bool "Enable PSK verification" - depends on ESP_TLS_USING_MBEDTLS - select MBEDTLS_PSK_MODES - select MBEDTLS_KEY_EXCHANGE_PSK - select MBEDTLS_KEY_EXCHANGE_DHE_PSK - select MBEDTLS_KEY_EXCHANGE_ECDHE_PSK - select MBEDTLS_KEY_EXCHANGE_RSA_PSK + select MBEDTLS_PSK_MODES if ESP_TLS_USING_MBEDTLS + select MBEDTLS_KEY_EXCHANGE_PSK if ESP_TLS_USING_MBEDTLS + select MBEDTLS_KEY_EXCHANGE_DHE_PSK if ESP_TLS_USING_MBEDTLS + select MBEDTLS_KEY_EXCHANGE_ECDHE_PSK if ESP_TLS_USING_MBEDTLS + select MBEDTLS_KEY_EXCHANGE_RSA_PSK if ESP_TLS_USING_MBEDTLS default n help - Enable support for pre shared key ciphers, uses the mbedtls crypto library + Enable support for pre shared key ciphers, supported for both mbedTLS as well as + wolfSSL TLS library. config ESP_WOLFSSL_SMALL_CERT_VERIFY bool "Enable SMALL_CERT_VERIFY" diff --git a/components/esp-tls/esp_tls.c b/components/esp-tls/esp_tls.c index cd45d218..da4dfe29 100644 --- a/components/esp-tls/esp_tls.c +++ b/components/esp-tls/esp_tls.c @@ -60,6 +60,10 @@ static const char *TAG = "esp-tls"; #define _esp_tls_read esp_wolfssl_read #define _esp_tls_write esp_wolfssl_write #define _esp_tls_conn_delete esp_wolfssl_conn_delete +#ifdef CONFIG_ESP_TLS_SERVER +#define _esp_tls_server_session_create esp_wolfssl_server_session_create +#define _esp_tls_server_session_delete esp_wolfssl_server_session_delete +#endif /* CONFIG_ESP_TLS_SERVER */ #define _esp_tls_get_bytes_avail esp_wolfssl_get_bytes_avail #define _esp_tls_init_global_ca_store esp_wolfssl_init_global_ca_store #define _esp_tls_set_global_ca_store esp_wolfssl_set_global_ca_store /*!< Callback function for setting global CA store data for TLS/SSL */ @@ -115,8 +119,9 @@ esp_tls_t *esp_tls_init(void) return NULL; } #ifdef CONFIG_ESP_TLS_USING_MBEDTLS - tls->server_fd.fd = tls->sockfd = -1; + tls->server_fd.fd = -1; #endif + tls->sockfd = -1; return tls; } @@ -193,7 +198,11 @@ static esp_err_t esp_tcp_connect(const char *host, int hostlen, int port, int *s } if (cfg->non_block) { int flags = fcntl(fd, F_GETFL, 0); - fcntl(fd, F_SETFL, flags | O_NONBLOCK); + ret = fcntl(fd, F_SETFL, flags | O_NONBLOCK); + if (ret < 0) { + ESP_LOGE(TAG, "Failed to configure the socket as non-blocking (errno %d)", errno); + goto err_freesocket; + } } } @@ -240,8 +249,8 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c return -1; } if (!cfg) { - tls->_read = tcp_read; - tls->_write = tcp_write; + tls->read = tcp_read; + tls->write = tcp_write; ESP_LOGD(TAG, "non-tls connection established"); return 1; } @@ -259,9 +268,9 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c ms_to_timeval(cfg->timeout_ms, &tv); /* In case of non-blocking I/O, we use the select() API to check whether - connection has been estbalished or not*/ + connection has been established or not*/ if (select(tls->sockfd + 1, &tls->rset, &tls->wset, NULL, - cfg->timeout_ms ? &tv : NULL) == 0) { + cfg->timeout_ms>0 ? &tv : NULL) == 0) { ESP_LOGD(TAG, "select() timed out"); return 0; } @@ -286,8 +295,8 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c tls->conn_state = ESP_TLS_FAIL; return -1; } - tls->_read = _esp_tls_read; - tls->_write = _esp_tls_write; + tls->read = _esp_tls_read; + tls->write = _esp_tls_write; tls->conn_state = ESP_TLS_HANDSHAKE; /* falls through */ case ESP_TLS_HANDSHAKE: @@ -309,12 +318,13 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c */ esp_tls_t *esp_tls_conn_new(const char *hostname, int hostlen, int port, const esp_tls_cfg_t *cfg) { - esp_tls_t *tls = (esp_tls_t *)calloc(1, sizeof(esp_tls_t)); + esp_tls_t *tls = esp_tls_init(); if (!tls) { return NULL; } /* esp_tls_conn_new() API establishes connection in a blocking manner thus this loop ensures that esp_tls_conn_new() API returns only after connection is established unless there is an error*/ + size_t start = xTaskGetTickCount(); while (1) { int ret = esp_tls_low_level_conn(hostname, hostlen, port, cfg, tls); if (ret == 1) { @@ -323,6 +333,14 @@ esp_tls_t *esp_tls_conn_new(const char *hostname, int hostlen, int port, const e esp_tls_conn_delete(tls); ESP_LOGE(TAG, "Failed to open new connection"); return NULL; + } else if (ret == 0 && cfg->timeout_ms >= 0) { + size_t timeout_ticks = pdMS_TO_TICKS(cfg->timeout_ms); + uint32_t expired = xTaskGetTickCount() - start; + if (expired >= timeout_ticks) { + esp_tls_conn_delete(tls); + ESP_LOGE(TAG, "Failed to open new connection in specified timeout"); + return NULL; + } } } return NULL; @@ -330,8 +348,9 @@ esp_tls_t *esp_tls_conn_new(const char *hostname, int hostlen, int port, const e int esp_tls_conn_new_sync(const char *hostname, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_t *tls) { - /* esp_tls_conn_new_sync() is a sync alternative to esp_tls_conn_new_async() with symetric function prototype + /* esp_tls_conn_new_sync() is a sync alternative to esp_tls_conn_new_async() with symmetric function prototype it is an alternative to esp_tls_conn_new() which is left for compatibility reasons */ + size_t start = xTaskGetTickCount(); while (1) { int ret = esp_tls_low_level_conn(hostname, hostlen, port, cfg, tls); if (ret == 1) { @@ -339,6 +358,14 @@ int esp_tls_conn_new_sync(const char *hostname, int hostlen, int port, const esp } else if (ret == -1) { ESP_LOGE(TAG, "Failed to open new connection"); return -1; + } else if (ret == 0 && cfg->timeout_ms >= 0) { + size_t timeout_ticks = pdMS_TO_TICKS(cfg->timeout_ms); + uint32_t expired = xTaskGetTickCount() - start; + if (expired >= timeout_ticks) { + ESP_LOGW(TAG, "Failed to open new connection in specified timeout"); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, ESP_ERR_ESP_TLS_CONNECTION_TIMEOUT); + return 0; + } } } return 0; @@ -384,6 +411,7 @@ esp_tls_t *esp_tls_conn_http_new(const char *url, const esp_tls_cfg_t *cfg) get_port(url, &u), cfg, tls) == 1) { return tls; } + esp_tls_conn_delete(tls); return NULL; } @@ -409,6 +437,7 @@ mbedtls_x509_crt *esp_tls_get_global_ca_store(void) return _esp_tls_get_global_ca_store(); } +#endif /* CONFIG_ESP_TLS_USING_MBEDTLS */ #ifdef CONFIG_ESP_TLS_SERVER /** * @brief Create a server side TLS/SSL connection @@ -425,13 +454,22 @@ void esp_tls_server_session_delete(esp_tls_t *tls) return _esp_tls_server_session_delete(tls); } #endif /* CONFIG_ESP_TLS_SERVER */ -#endif /* CONFIG_ESP_TLS_USING_MBEDTLS */ ssize_t esp_tls_get_bytes_avail(esp_tls_t *tls) { return _esp_tls_get_bytes_avail(tls); } +esp_err_t esp_tls_get_conn_sockfd(esp_tls_t *tls, int *sockfd) +{ + if (!tls || !sockfd) { + ESP_LOGE(TAG, "Invalid arguments passed"); + return ESP_ERR_INVALID_ARG; + } + *sockfd = tls->sockfd; + return ESP_OK; +} + esp_err_t esp_tls_get_and_clear_last_error(esp_tls_error_handle_t h, int *esp_tls_code, int *esp_tls_flags) { if (!h) { diff --git a/components/esp-tls/esp_tls.h b/components/esp-tls/esp_tls.h index 7e9e0b51..a5983b2e 100644 --- a/components/esp-tls/esp_tls.h +++ b/components/esp-tls/esp_tls.h @@ -54,6 +54,15 @@ extern "C" { #define ESP_ERR_MBEDTLS_PK_PARSE_KEY_FAILED (ESP_ERR_ESP_TLS_BASE + 0x0F) /*!< mbedtls api returned failed */ #define ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED (ESP_ERR_ESP_TLS_BASE + 0x10) /*!< mbedtls api returned failed */ #define ESP_ERR_MBEDTLS_SSL_CONF_PSK_FAILED (ESP_ERR_ESP_TLS_BASE + 0x11) /*!< mbedtls api returned failed */ +#define ESP_ERR_ESP_TLS_CONNECTION_TIMEOUT (ESP_ERR_ESP_TLS_BASE + 0x12) /*!< new connection in esp_tls_low_level_conn connection timeouted */ +#define ESP_ERR_WOLFSSL_SSL_SET_HOSTNAME_FAILED (ESP_ERR_ESP_TLS_BASE + 0x13) /*!< wolfSSL api returned error */ +#define ESP_ERR_WOLFSSL_SSL_CONF_ALPN_PROTOCOLS_FAILED (ESP_ERR_ESP_TLS_BASE + 0x14) /*!< wolfSSL api returned error */ +#define ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED (ESP_ERR_ESP_TLS_BASE + 0x15) /*!< wolfSSL api returned error */ +#define ESP_ERR_WOLFSSL_KEY_VERIFY_SETUP_FAILED (ESP_ERR_ESP_TLS_BASE + 0x16) /*!< wolfSSL api returned error */ +#define ESP_ERR_WOLFSSL_SSL_HANDSHAKE_FAILED (ESP_ERR_ESP_TLS_BASE + 0x17) /*!< wolfSSL api returned failed */ +#define ESP_ERR_WOLFSSL_CTX_SETUP_FAILED (ESP_ERR_ESP_TLS_BASE + 0x18) /*!< wolfSSL api returned failed */ +#define ESP_ERR_WOLFSSL_SSL_SETUP_FAILED (ESP_ERR_ESP_TLS_BASE + 0x19) /*!< wolfSSL api returned failed */ +#define ESP_ERR_WOLFSSL_SSL_WRITE_FAILED (ESP_ERR_ESP_TLS_BASE + 0x1A) /*!< wolfSSL api returned failed */ #ifdef CONFIG_ESP_TLS_USING_MBEDTLS #define ESP_TLS_ERR_SSL_WANT_READ MBEDTLS_ERR_SSL_WANT_READ @@ -101,8 +110,8 @@ typedef struct psk_key_hint { } psk_hint_key_t; /** - * @brief ESP-TLS configuration parameters - * + * @brief ESP-TLS configuration parameters + * * @note Note about format of certificates: * - This structure includes certificates of a Certificate Authority, of client or server as well * as private keys, which may be of PEM or DER format. In case of PEM format, the buffer must be @@ -112,17 +121,17 @@ typedef struct psk_key_hint { * - Variables names of certificates and private key buffers and sizes are defined as unions providing * backward compatibility for legacy *_pem_buf and *_pem_bytes names which suggested only PEM format * was supported. It is encouraged to use generic names such as cacert_buf and cacert_bytes. - */ + */ typedef struct esp_tls_cfg { const char **alpn_protos; /*!< Application protocols required for HTTP2. If HTTP2/ALPN support is required, a list - of protocols that should be negotiated. + of protocols that should be negotiated. The format is length followed by protocol - name. + name. For the most common cases the following is ok: const char **alpn_protos = { "h2", NULL }; - where 'h2' is the protocol name */ - + union { const unsigned char *cacert_buf; /*!< Certificate Authority's certificate in a buffer. Format may be PEM or DER, depending on mbedtls-support @@ -170,8 +179,8 @@ typedef struct esp_tls_cfg { unsigned int clientkey_password_len; /*!< String length of the password pointed to by clientkey_password */ - bool non_block; /*!< Configure non-blocking mode. If set to true the - underneath socket will be configured in non + bool non_block; /*!< Configure non-blocking mode. If set to true the + underneath socket will be configured in non blocking mode after tls session is established */ int timeout_ms; /*!< Network timeout in milliseconds */ @@ -188,6 +197,10 @@ typedef struct esp_tls_cfg { then PSK authentication is enabled with configured setup. Important note: the pointer must be valid for connection */ + esp_err_t (*crt_bundle_attach)(void *conf); + /*!< Function pointer to esp_crt_bundle_attach. Enables the use of certification + bundle for server verification, must be enabled in menuconfig */ + } esp_tls_cfg_t; #ifdef CONFIG_ESP_TLS_SERVER @@ -246,46 +259,24 @@ typedef struct esp_tls_cfg_server { #endif /* ! CONFIG_ESP_TLS_SERVER */ /** - * @brief ESP-TLS Connection Handle + * @brief ESP-TLS Connection Handle */ typedef struct esp_tls { - int sockfd; /*!< Underlying socket file descriptor. */ - - ssize_t (*_read)(struct esp_tls *tls, char *data, size_t datalen); /*!< Callback function for reading data from TLS/SSL - connection. */ - - ssize_t (*_write)(struct esp_tls *tls, const char *data, size_t datalen); /*!< Callback function for writing data to TLS/SSL - connection. */ - - esp_tls_conn_state_t conn_state; /*!< ESP-TLS Connection state */ - - fd_set rset; /*!< read file descriptors */ - - fd_set wset; /*!< write file descriptors */ - - bool is_tls; /*!< indicates connection type (TLS or NON-TLS) */ - - esp_tls_role_t role; /*!< esp-tls role - - ESP_TLS_CLIENT - - ESP_TLS_SERVER */ - - esp_tls_error_handle_t error_handle; /*!< handle to error descriptor */ - #ifdef CONFIG_ESP_TLS_USING_MBEDTLS mbedtls_ssl_context ssl; /*!< TLS/SSL context */ - + mbedtls_entropy_context entropy; /*!< mbedTLS entropy context structure */ - + mbedtls_ctr_drbg_context ctr_drbg; /*!< mbedTLS ctr drbg context structure. - CTR_DRBG is deterministic random + CTR_DRBG is deterministic random bit generation based on AES-256 */ - - mbedtls_ssl_config conf; /*!< TLS/SSL configuration to be shared - between mbedtls_ssl_context + + mbedtls_ssl_config conf; /*!< TLS/SSL configuration to be shared + between mbedtls_ssl_context structures */ - + mbedtls_net_context server_fd; /*!< mbedTLS wrapper type for sockets */ - + mbedtls_x509_crt cacert; /*!< Container for the X.509 CA certificate */ mbedtls_x509_crt *cacert_ptr; /*!< Pointer to the cacert being used. */ @@ -304,6 +295,28 @@ typedef struct esp_tls { void *priv_ctx; void *priv_ssl; #endif + int sockfd; /*!< Underlying socket file descriptor. */ + + ssize_t (*read)(struct esp_tls *tls, char *data, size_t datalen); /*!< Callback function for reading data from TLS/SSL + connection. */ + + ssize_t (*write)(struct esp_tls *tls, const char *data, size_t datalen); /*!< Callback function for writing data to TLS/SSL + connection. */ + + esp_tls_conn_state_t conn_state; /*!< ESP-TLS Connection state */ + + fd_set rset; /*!< read file descriptors */ + + fd_set wset; /*!< write file descriptors */ + + bool is_tls; /*!< indicates connection type (TLS or NON-TLS) */ + + esp_tls_role_t role; /*!< esp-tls role + - ESP_TLS_CLIENT + - ESP_TLS_SERVER */ + + esp_tls_error_handle_t error_handle; /*!< handle to error descriptor */ + } esp_tls_t; @@ -358,7 +371,7 @@ esp_tls_t *esp_tls_conn_new(const char *hostname, int hostlen, int port, const e * @return * - -1 If connection establishment fails. * - 1 If connection establishment is successful. - * - 0 Reserved for connection state is in progress. + * - 0 If connection state is in progress. */ int esp_tls_conn_new_sync(const char *hostname, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_t *tls); @@ -366,7 +379,7 @@ int esp_tls_conn_new_sync(const char *hostname, int hostlen, int port, const esp * @brief Create a new blocking TLS/SSL connection with a given "HTTP" url * * The behaviour is same as esp_tls_conn_new() API. However this API accepts host's url. - * + * * @param[in] url url of host. * @param[in] cfg TLS configuration as esp_tls_cfg_t. If you wish to open * non-TLS connection, keep this NULL. For TLS connection, @@ -375,7 +388,7 @@ int esp_tls_conn_new_sync(const char *hostname, int hostlen, int port, const esp * @return pointer to esp_tls_t, or NULL if connection couldn't be opened. */ esp_tls_t *esp_tls_conn_http_new(const char *url, const esp_tls_cfg_t *cfg); - + /** * @brief Create a new non-blocking TLS/SSL connection * @@ -414,30 +427,30 @@ int esp_tls_conn_http_new_async(const char *url, const esp_tls_cfg_t *cfg, esp_t /** * @brief Write from buffer 'data' into specified tls connection. - * + * * @param[in] tls pointer to esp-tls as esp-tls handle. * @param[in] data Buffer from which data will be written. * @param[in] datalen Length of data buffer. - * - * @return - * - >0 if write operation was successful, the return value is the number - * of bytes actually written to the TLS/SSL connection. + * + * @return + * - >0 if write operation was successful, the return value is the number + * of bytes actually written to the TLS/SSL connection. * - 0 if write operation was not successful. The underlying * connection was closed. - * - <0 if write operation was not successful, because either an - * error occured or an action must be taken by the calling process. + * - <0 if write operation was not successful, because either an + * error occured or an action must be taken by the calling process. */ static inline ssize_t esp_tls_conn_write(esp_tls_t *tls, const void *data, size_t datalen) { - return tls->_write(tls, (char *)data, datalen); + return tls->write(tls, (char *)data, datalen); } /** * @brief Read from specified tls connection into the buffer 'data'. - * + * * @param[in] tls pointer to esp-tls as esp-tls handle. * @param[in] data Buffer to hold read data. - * @param[in] datalen Length of data buffer. + * @param[in] datalen Length of data buffer. * * @return * - >0 if read operation was successful, the return value is the number @@ -449,16 +462,16 @@ static inline ssize_t esp_tls_conn_write(esp_tls_t *tls, const void *data, size_ */ static inline ssize_t esp_tls_conn_read(esp_tls_t *tls, void *data, size_t datalen) { - return tls->_read(tls, (char *)data, datalen); + return tls->read(tls, (char *)data, datalen); } /** * @brief Close the TLS/SSL connection and free any allocated resources. - * - * This function should be called to close each tls connection opened with esp_tls_conn_new() or - * esp_tls_conn_http_new() APIs. * - * @param[in] tls pointer to esp-tls as esp-tls handle. + * This function should be called to close each tls connection opened with esp_tls_conn_new() or + * esp_tls_conn_http_new() APIs. + * + * @param[in] tls pointer to esp-tls as esp-tls handle. */ void esp_tls_conn_delete(esp_tls_t *tls); @@ -477,6 +490,18 @@ void esp_tls_conn_delete(esp_tls_t *tls); */ ssize_t esp_tls_get_bytes_avail(esp_tls_t *tls); +/** + * @brief Returns the connection socket file descriptor from esp_tls session + * + * @param[in] tls handle to esp_tls context + * + * @param[out] sockfd int pointer to sockfd value. + * + * @return - ESP_OK on success and value of sockfd will be updated with socket file descriptor for connection + * - ESP_ERR_INVALID_ARG if (tls == NULL || sockfd == NULL) + */ +esp_err_t esp_tls_get_conn_sockfd(esp_tls_t *tls, int *sockfd); + /** * @brief Create a global CA store, initially empty. * @@ -549,6 +574,7 @@ esp_err_t esp_tls_get_and_clear_last_error(esp_tls_error_handle_t h, int *esp_tl */ mbedtls_x509_crt *esp_tls_get_global_ca_store(void); +#endif /* CONFIG_ESP_TLS_USING_MBEDTLS */ #ifdef CONFIG_ESP_TLS_SERVER /** * @brief Create TLS/SSL server session @@ -576,7 +602,6 @@ int esp_tls_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp_tls */ void esp_tls_server_session_delete(esp_tls_t *tls); #endif /* ! CONFIG_ESP_TLS_SERVER */ -#endif /* CONFIG_ESP_TLS_USING_MBEDTLS */ #ifdef __cplusplus } diff --git a/components/esp-tls/esp_tls_mbedtls.c b/components/esp-tls/esp_tls_mbedtls.c index 702d2b81..245921fd 100644 --- a/components/esp-tls/esp_tls_mbedtls.c +++ b/components/esp-tls/esp_tls_mbedtls.c @@ -26,6 +26,11 @@ #include #include "esp_log.h" +#ifdef CONFIG_MBEDTLS_CERTIFICATE_BUNDLE +#include "esp_crt_bundle.h" +#endif + + static const char *TAG = "esp-tls-mbedtls"; static mbedtls_x509_crt *global_cacert = NULL; @@ -266,7 +271,7 @@ static esp_err_t set_pki_context(esp_tls_t *tls, const esp_tls_pki_t *pki) } ret = mbedtls_pk_parse_key(pki->pk_key, pki->privkey_pem_buf, pki->privkey_pem_bytes, - NULL, 0); + pki->privkey_password, pki->privkey_password_len); if (ret < 0) { ESP_LOGE(TAG, "mbedtls_pk_parse_keyfile returned -0x%x", -ret); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_MBEDTLS, -ret); @@ -389,16 +394,30 @@ esp_err_t set_client_config(const char *hostname, size_t hostlen, esp_tls_cfg_t return ESP_ERR_MBEDTLS_SSL_CONFIG_DEFAULTS_FAILED; } -#ifdef CONFIG_MBEDTLS_SSL_ALPN + if (cfg->alpn_protos) { +#ifdef CONFIG_MBEDTLS_SSL_ALPN if ((ret = mbedtls_ssl_conf_alpn_protocols(&tls->conf, cfg->alpn_protos) != 0)) { ESP_LOGE(TAG, "mbedtls_ssl_conf_alpn_protocols returned -0x%x", -ret); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_MBEDTLS, -ret); return ESP_ERR_MBEDTLS_SSL_CONF_ALPN_PROTOCOLS_FAILED; } - } +#else + ESP_LOGE(TAG, "alpn_protos configured but not enabled in menuconfig: Please enable MBEDTLS_SSL_ALPN option"); + return ESP_ERR_INVALID_STATE; #endif - if (cfg->use_global_ca_store == true) { + } + + if (cfg->crt_bundle_attach != NULL) { +#ifdef CONFIG_MBEDTLS_CERTIFICATE_BUNDLE + ESP_LOGD(TAG, "Use certificate bundle"); + mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + cfg->crt_bundle_attach(&tls->conf); +#else //CONFIG_MBEDTLS_CERTIFICATE_BUNDLE + ESP_LOGE(TAG, "use_crt_bundle configured but not enabled in menuconfig: Please enable MBEDTLS_CERTIFICATE_BUNDLE option"); + return ESP_ERR_INVALID_STATE; +#endif + } else if (cfg->use_global_ca_store == true) { esp_err_t esp_ret = set_global_ca_store(tls); if (esp_ret != ESP_OK) { return esp_ret; @@ -470,8 +489,8 @@ int esp_mbedtls_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp tls->conn_state = ESP_TLS_FAIL; return -1; } - tls->_read = esp_mbedtls_read; - tls->_write = esp_mbedtls_write; + tls->read = esp_mbedtls_read; + tls->write = esp_mbedtls_write; int ret; while ((ret = mbedtls_ssl_handshake(&tls->ssl)) != 0) { if (ret != ESP_TLS_ERR_SSL_WANT_READ && ret != ESP_TLS_ERR_SSL_WANT_WRITE) { diff --git a/components/esp-tls/esp_tls_wolfssl.c b/components/esp-tls/esp_tls_wolfssl.c index 351ae071..7c99830c 100644 --- a/components/esp-tls/esp_tls_wolfssl.c +++ b/components/esp-tls/esp_tls_wolfssl.c @@ -31,16 +31,86 @@ static unsigned char *global_cacert = NULL; static unsigned int global_cacert_pem_bytes = 0; static const char *TAG = "esp-tls-wolfssl"; -int esp_create_wolfssl_handle(const char *hostname, size_t hostlen, const void *cfg1, esp_tls_t *tls) +/* Prototypes for the static functions */ +static esp_err_t set_client_config(const char *hostname, size_t hostlen, esp_tls_cfg_t *cfg, esp_tls_t *tls); + +#if defined(CONFIG_ESP_TLS_PSK_VERIFICATION) +#include "freertos/semphr.h" +static SemaphoreHandle_t tls_conn_lock; +static inline unsigned int esp_wolfssl_psk_client_cb(WOLFSSL* ssl, const char* hint, char* identity, + unsigned int id_max_len, unsigned char* key,unsigned int key_max_len); +static esp_err_t esp_wolfssl_set_cipher_list(WOLFSSL_CTX *ctx); +#ifdef WOLFSSL_TLS13 +#define PSK_MAX_ID_LEN 128 +#else +#define PSK_MAX_ID_LEN 64 +#endif +#define PSK_MAX_KEY_LEN 64 + +static char psk_id_str[PSK_MAX_ID_LEN]; +static uint8_t psk_key_array[PSK_MAX_KEY_LEN]; +static uint8_t psk_key_max_len = 0; +#endif /* CONFIG_ESP_TLS_PSK_VERIFICATION */ + +#ifdef CONFIG_ESP_TLS_SERVER +static esp_err_t set_server_config(esp_tls_cfg_server_t *cfg, esp_tls_t *tls); +#endif /* CONFIG_ESP_TLS_SERVER */ + +typedef enum x509_file_type { + FILE_TYPE_CA_CERT = 0, /* CA certificate to authenticate entity at other end */ + FILE_TYPE_SELF_CERT, /* Self certificate of the entity */ + FILE_TYPE_SELF_KEY, /* Private key in the self cert-key pair */ +} x509_file_type_t; + +/* Checks whether the certificate provided is in pem format or not */ +static esp_err_t esp_load_wolfssl_verify_buffer(esp_tls_t *tls, const unsigned char *cert_buf, unsigned int cert_len, x509_file_type_t type, int *err_ret) +{ + int wolf_fileformat = WOLFSSL_FILETYPE_DEFAULT; + if (type == FILE_TYPE_SELF_KEY) { + if (cert_buf[cert_len - 1] == '\0' && strstr( (const char *) cert_buf, "-----BEGIN " )) { + wolf_fileformat = WOLFSSL_FILETYPE_PEM; + } else { + wolf_fileformat = WOLFSSL_FILETYPE_ASN1; + } + if ((*err_ret = wolfSSL_CTX_use_PrivateKey_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cert_buf, cert_len, wolf_fileformat)) == WOLFSSL_SUCCESS) { + return ESP_OK; + } + return ESP_FAIL; + } else { + if (cert_buf[cert_len - 1] == '\0' && strstr( (const char *) cert_buf, "-----BEGIN CERTIFICATE-----" )) { + wolf_fileformat = WOLFSSL_FILETYPE_PEM; + } else { + wolf_fileformat = WOLFSSL_FILETYPE_ASN1; + } + if (type == FILE_TYPE_SELF_CERT) { + if ((*err_ret = wolfSSL_CTX_use_certificate_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cert_buf, cert_len, wolf_fileformat)) == WOLFSSL_SUCCESS) { + return ESP_OK; + } + return ESP_FAIL; + } else if (type == FILE_TYPE_CA_CERT) { + if ((*err_ret = wolfSSL_CTX_load_verify_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cert_buf, cert_len, wolf_fileformat)) == WOLFSSL_SUCCESS) { + return ESP_OK; + } + return ESP_FAIL; + } else { + /* Wrong file type provided */ + return ESP_FAIL; + } + } +} + +esp_err_t esp_create_wolfssl_handle(const char *hostname, size_t hostlen, const void *cfg, esp_tls_t *tls) { #ifdef CONFIG_ESP_DEBUG_WOLFSSL wolfSSL_Debugging_ON(); #endif - const esp_tls_cfg_t *cfg = cfg1; + assert(cfg != NULL); assert(tls != NULL); + esp_err_t esp_ret = ESP_FAIL; int ret; + ret = wolfSSL_Init(); if (ret != WOLFSSL_SUCCESS) { ESP_LOGE(TAG, "Init wolfSSL failed: %d", ret); @@ -48,71 +118,206 @@ int esp_create_wolfssl_handle(const char *hostname, size_t hostlen, const void * goto exit; } + if (tls->role == ESP_TLS_CLIENT) { + esp_ret = set_client_config(hostname, hostlen, (esp_tls_cfg_t *)cfg, tls); + if (esp_ret != ESP_OK) { + ESP_LOGE(TAG, "Failed to set client configurations"); + goto exit; + } + } else if (tls->role == ESP_TLS_SERVER) { +#ifdef CONFIG_ESP_TLS_SERVER + esp_ret = set_server_config((esp_tls_cfg_server_t *) cfg, tls); + if (esp_ret != ESP_OK) { + ESP_LOGE(TAG, "Failed to set server configurations"); + goto exit; + } +#else + ESP_LOGE(TAG, "ESP_TLS_SERVER Not enabled in menuconfig"); + goto exit; +#endif + } + else { + ESP_LOGE(TAG, "tls->role is not valid"); + goto exit; + } + + return ESP_OK; +exit: + esp_wolfssl_cleanup(tls); + return esp_ret; +} + +static esp_err_t set_client_config(const char *hostname, size_t hostlen, esp_tls_cfg_t *cfg, esp_tls_t *tls) +{ + int ret = WOLFSSL_FAILURE; tls->priv_ctx = (void *)wolfSSL_CTX_new(wolfTLSv1_2_client_method()); if (!tls->priv_ctx) { ESP_LOGE(TAG, "Set wolfSSL ctx failed"); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); - goto exit; + return ESP_ERR_WOLFSSL_CTX_SETUP_FAILED; } -#ifdef HAVE_ALPN - if (cfg->alpn_protos) { - char **alpn_list = (char **)cfg->alpn_protos; - for (; *alpn_list != NULL; alpn_list ++) { - if (wolfSSL_UseALPN( (WOLFSSL *)tls->priv_ssl, *alpn_list, strlen(*alpn_list), WOLFSSL_ALPN_FAILED_ON_MISMATCH) != WOLFSSL_SUCCESS) { - ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); - ESP_LOGE(TAG, "Use wolfSSL ALPN failed"); - goto exit; - } + if (cfg->use_global_ca_store == true) { + if ((esp_load_wolfssl_verify_buffer(tls, global_cacert, global_cacert_pem_bytes, FILE_TYPE_CA_CERT, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading certificate verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; } - } -#endif - - if ( cfg->use_global_ca_store == true) { - wolfSSL_CTX_load_verify_buffer( (WOLFSSL_CTX *)tls->priv_ctx, global_cacert, global_cacert_pem_bytes, WOLFSSL_FILETYPE_PEM); - wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, SSL_VERIFY_PEER, NULL); - } else if (cfg->cacert_pem_buf != NULL) { - wolfSSL_CTX_load_verify_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cfg->cacert_pem_buf, cfg->cacert_pem_bytes, WOLFSSL_FILETYPE_PEM); - wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, SSL_VERIFY_PEER, NULL); + wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, WOLFSSL_VERIFY_PEER, NULL); + } else if (cfg->cacert_buf != NULL) { + if ((esp_load_wolfssl_verify_buffer(tls, cfg->cacert_buf, cfg->cacert_bytes, FILE_TYPE_CA_CERT, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading certificate verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, WOLFSSL_VERIFY_PEER, NULL); } else if (cfg->psk_hint_key) { - ESP_LOGE(TAG,"psk_hint_key not supported in wolfssl"); - goto exit; +#if defined(CONFIG_ESP_TLS_PSK_VERIFICATION) + /*** PSK encryption mode is configured only if no certificate supplied and psk pointer not null ***/ + if(cfg->psk_hint_key->key == NULL || cfg->psk_hint_key->hint == NULL || cfg->psk_hint_key->key_size <= 0) { + ESP_LOGE(TAG, "Please provide appropriate key, keysize and hint to use PSK"); + return ESP_FAIL; + } + /* mutex is given back when call back function executes or in case of failure (at cleanup) */ + if ((xSemaphoreTake(tls_conn_lock, 1000/portTICK_PERIOD_MS) != pdTRUE)) { + ESP_LOGE(TAG, "tls_conn_lock could not be obtained in specified time"); + return -1; + } + ESP_LOGI(TAG, "setting psk configurations"); + if((cfg->psk_hint_key->key_size > PSK_MAX_KEY_LEN) || (strlen(cfg->psk_hint_key->hint) > PSK_MAX_ID_LEN)) { + ESP_LOGE(TAG, "psk key length should be <= %d and identity hint length should be <= %d", PSK_MAX_KEY_LEN, PSK_MAX_ID_LEN); + return ESP_ERR_INVALID_ARG; + } + psk_key_max_len = cfg->psk_hint_key->key_size; + memset(psk_key_array, 0, sizeof(psk_key_array)); + memset(psk_id_str, 0, sizeof(psk_id_str)); + memcpy(psk_key_array, cfg->psk_hint_key->key, psk_key_max_len); + memcpy(psk_id_str, cfg->psk_hint_key->hint, strlen(cfg->psk_hint_key->hint)); + wolfSSL_CTX_set_psk_client_callback( (WOLFSSL_CTX *)tls->priv_ctx, esp_wolfssl_psk_client_cb); + if(esp_wolfssl_set_cipher_list( (WOLFSSL_CTX *)tls->priv_ctx) != ESP_OK) { + ESP_LOGE(TAG, "error in setting cipher-list"); + return ESP_FAIL; + } +#else + ESP_LOGE(TAG, "psk_hint_key configured but not enabled in menuconfig: Please enable ESP_TLS_PSK_VERIFICATION option"); + return ESP_ERR_INVALID_STATE; +#endif } else { - wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, SSL_VERIFY_NONE, NULL); + wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, WOLFSSL_VERIFY_NONE, NULL); } - if (cfg->clientcert_pem_buf != NULL && cfg->clientkey_pem_buf != NULL) { - wolfSSL_CTX_use_certificate_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cfg->clientcert_pem_buf, cfg->clientcert_pem_bytes, WOLFSSL_FILETYPE_PEM); - wolfSSL_CTX_use_PrivateKey_buffer( (WOLFSSL_CTX *)tls->priv_ctx, cfg->clientkey_pem_buf, cfg->clientkey_pem_bytes, WOLFSSL_FILETYPE_PEM); - } else if (cfg->clientcert_pem_buf != NULL || cfg->clientkey_pem_buf != NULL) { - ESP_LOGE(TAG, "You have to provide both clientcert_pem_buf and clientkey_pem_buf for mutual authentication\n\n"); - goto exit; + if (cfg->clientcert_buf != NULL && cfg->clientkey_buf != NULL) { + if ((esp_load_wolfssl_verify_buffer(tls,cfg->clientcert_buf, cfg->clientcert_bytes, FILE_TYPE_SELF_CERT, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading certificate verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + if ((esp_load_wolfssl_verify_buffer(tls,cfg->clientkey_buf, cfg->clientkey_bytes, FILE_TYPE_SELF_KEY, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading private key verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + } else if (cfg->clientcert_buf != NULL || cfg->clientkey_buf != NULL) { + ESP_LOGE(TAG, "You have to provide both clientcert_buf and clientkey_buf for mutual authentication\n\n"); + return ESP_FAIL; + } + + if (cfg->crt_bundle_attach != NULL) { + ESP_LOGE(TAG,"use_crt_bundle not supported in wolfssl"); + return ESP_FAIL; } tls->priv_ssl =(void *)wolfSSL_new( (WOLFSSL_CTX *)tls->priv_ctx); if (!tls->priv_ssl) { ESP_LOGE(TAG, "Create wolfSSL failed"); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); - goto exit; + return ESP_ERR_WOLFSSL_SSL_SETUP_FAILED; } -#ifdef HAVE_SNI - /* Hostname set here should match CN in server certificate */ - char *use_host = strndup(hostname, hostlen); - if (!use_host) { - goto exit; + if (!cfg->skip_common_name) { + char *use_host = NULL; + if (cfg->common_name != NULL) { + use_host = strdup(cfg->common_name); + } else { + use_host = strndup(hostname, hostlen); + } + if (use_host == NULL) { + return ESP_ERR_NO_MEM; + } + /* Hostname set here should match CN in server certificate */ + if ((ret = wolfSSL_set_tlsext_host_name( (WOLFSSL *)tls->priv_ssl, use_host))!= WOLFSSL_SUCCESS) { + ESP_LOGE(TAG, "wolfSSL_set_tlsext_host_name returned -0x%x", -ret); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); + free(use_host); + return ESP_ERR_WOLFSSL_SSL_SET_HOSTNAME_FAILED; + } + free(use_host); + } + + if (cfg->alpn_protos) { +#ifdef CONFIG_WOLFSSL_HAVE_ALPN + char **alpn_list = (char **)cfg->alpn_protos; + for (; *alpn_list != NULL; alpn_list ++) { + ESP_LOGD(TAG, "alpn protocol is %s", *alpn_list); + if ((ret = wolfSSL_UseALPN( (WOLFSSL *)tls->priv_ssl, *alpn_list, strlen(*alpn_list), WOLFSSL_ALPN_FAILED_ON_MISMATCH)) != WOLFSSL_SUCCESS) { + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); + ESP_LOGE(TAG, "wolfSSL UseALPN failed, returned %d", ret); + return ESP_ERR_WOLFSSL_SSL_CONF_ALPN_PROTOCOLS_FAILED; + } + } +#else + ESP_LOGE(TAG, "CONFIG_WOLFSSL_HAVE_ALPN not enabled in menuconfig"); + return ESP_FAIL; +#endif /* CONFIG_WOLFSSL_HAVE_ALPN */ } - wolfSSL_set_tlsext_host_name( (WOLFSSL *)tls->priv_ssl, use_host); - free(use_host); -#endif wolfSSL_set_fd((WOLFSSL *)tls->priv_ssl, tls->sockfd); - return 0; -exit: - esp_wolfssl_cleanup(tls); - return ret; + return ESP_OK; } +#ifdef CONFIG_ESP_TLS_SERVER +static esp_err_t set_server_config(esp_tls_cfg_server_t *cfg, esp_tls_t *tls) +{ + int ret = WOLFSSL_FAILURE; + tls->priv_ctx = (void *)wolfSSL_CTX_new(wolfTLSv1_2_server_method()); + if (!tls->priv_ctx) { + ESP_LOGE(TAG, "Set wolfSSL ctx failed"); + return ESP_ERR_WOLFSSL_CTX_SETUP_FAILED; + } + + if (cfg->cacert_buf != NULL) { + if ((esp_load_wolfssl_verify_buffer(tls,cfg->cacert_buf, cfg->cacert_bytes, FILE_TYPE_CA_CERT, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading certificate verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, WOLFSSL_VERIFY_PEER | WOLFSSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); + ESP_LOGD(TAG," Verify Client for Mutual Auth"); + } else { + ESP_LOGD(TAG," Not verifying Client "); + wolfSSL_CTX_set_verify( (WOLFSSL_CTX *)tls->priv_ctx, WOLFSSL_VERIFY_NONE, NULL); + } + + if (cfg->servercert_buf != NULL && cfg->serverkey_buf != NULL) { + if ((esp_load_wolfssl_verify_buffer(tls,cfg->servercert_buf, cfg->servercert_bytes, FILE_TYPE_SELF_CERT, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading certificate verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + if ((esp_load_wolfssl_verify_buffer(tls,cfg->serverkey_buf, cfg->serverkey_bytes, FILE_TYPE_SELF_KEY, &ret)) != ESP_OK) { + ESP_LOGE(TAG, "Error in loading private key verify buffer, returned %d", ret); + return ESP_ERR_WOLFSSL_CERT_VERIFY_SETUP_FAILED; + } + } else { + ESP_LOGE(TAG, "You have to provide both servercert_buf and serverkey_buf for https_server\n\n"); + return ESP_FAIL; + } + + tls->priv_ssl =(void *)wolfSSL_new( (WOLFSSL_CTX *)tls->priv_ctx); + if (!tls->priv_ssl) { + ESP_LOGE(TAG, "Create wolfSSL failed"); + return ESP_ERR_WOLFSSL_SSL_SETUP_FAILED; + } + + wolfSSL_set_fd((WOLFSSL *)tls->priv_ssl, tls->sockfd); + return ESP_OK; +} +#endif + int esp_wolfssl_handshake(esp_tls_t *tls, const esp_tls_cfg_t *cfg) { int ret; @@ -125,8 +330,8 @@ int esp_wolfssl_handshake(esp_tls_t *tls, const esp_tls_cfg_t *cfg) if (err != ESP_TLS_ERR_SSL_WANT_READ && err != ESP_TLS_ERR_SSL_WANT_WRITE) { ESP_LOGE(TAG, "wolfSSL_connect returned -0x%x", -ret); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); - - if (cfg->cacert_pem_buf != NULL || cfg->use_global_ca_store == true) { + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, ESP_ERR_WOLFSSL_SSL_HANDSHAKE_FAILED); + if (cfg->cacert_buf != NULL || cfg->use_global_ca_store == true) { /* This is to check whether handshake failed due to invalid certificate*/ esp_wolfssl_verify_certificate(tls); } @@ -164,7 +369,9 @@ ssize_t esp_wolfssl_write(esp_tls_t *tls, const char *data, size_t datalen) ret = wolfSSL_get_error( (WOLFSSL *)tls->priv_ssl, ret); if (ret != ESP_TLS_ERR_SSL_WANT_READ && ret != ESP_TLS_ERR_SSL_WANT_WRITE) { ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, ESP_ERR_WOLFSSL_SSL_WRITE_FAILED); ESP_LOGE(TAG, "write error :%d:", ret); + } } return ret; @@ -174,7 +381,7 @@ void esp_wolfssl_verify_certificate(esp_tls_t *tls) { int flags; if ((flags = wolfSSL_get_verify_result( (WOLFSSL *)tls->priv_ssl)) != WOLFSSL_SUCCESS) { - ESP_LOGE(TAG, "Failed to verify peer certificate %d!", flags); + ESP_LOGE(TAG, "Failed to verify peer certificate , returned %d!", flags); ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL_CERT_FLAGS, flags); } else { ESP_LOGI(TAG, "Certificate verified."); @@ -202,12 +409,61 @@ void esp_wolfssl_cleanup(esp_tls_t *tls) if (!tls) { return; } +#ifdef CONFIG_ESP_TLS_PSK_VERIFICATION + xSemaphoreGive(tls_conn_lock); +#endif /* CONFIG_ESP_TLS_PSK_VERIFICATION */ wolfSSL_shutdown( (WOLFSSL *)tls->priv_ssl); wolfSSL_free( (WOLFSSL *)tls->priv_ssl); + tls->priv_ssl = NULL; wolfSSL_CTX_free( (WOLFSSL_CTX *)tls->priv_ctx); + tls->priv_ctx = NULL; wolfSSL_Cleanup(); } +#ifdef CONFIG_ESP_TLS_SERVER +/** + * @brief Create TLS/SSL server session + */ +int esp_wolfssl_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp_tls_t *tls) +{ + if (tls == NULL || cfg == NULL) { + return -1; + } + tls->role = ESP_TLS_SERVER; + tls->sockfd = sockfd; + esp_err_t esp_ret = esp_create_wolfssl_handle(NULL, 0, cfg, tls); + if (esp_ret != ESP_OK) { + ESP_LOGE(TAG, "create_ssl_handle failed"); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, esp_ret); + tls->conn_state = ESP_TLS_FAIL; + return -1; + } + tls->read = esp_wolfssl_read; + tls->write = esp_wolfssl_write; + int ret; + while ((ret = wolfSSL_accept((WOLFSSL *)tls->priv_ssl)) != WOLFSSL_SUCCESS) { + if (ret != ESP_TLS_ERR_SSL_WANT_READ && ret != ESP_TLS_ERR_SSL_WANT_WRITE) { + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_WOLFSSL, -ret); + ESP_LOGE(TAG, "wolfSSL_handshake_server returned %d", ret); + tls->conn_state = ESP_TLS_FAIL; + return ret; + } + } + return 0; +} + +/** + * @brief Close the server side TLS/SSL connection and free any allocated resources. + */ +void esp_wolfssl_server_session_delete(esp_tls_t *tls) +{ + if (tls != NULL) { + esp_wolfssl_cleanup(tls); + free(tls); + } +} +#endif /* CONFIG_ESP_TLS_SERVER */ + esp_err_t esp_wolfssl_init_global_ca_store(void) { /* This function is just to provide consistancy between function calls of esp_tls.h and wolfssl */ @@ -242,3 +498,56 @@ void esp_wolfssl_free_global_ca_store(void) global_cacert_pem_bytes = 0; } } + +#if defined(CONFIG_ESP_TLS_PSK_VERIFICATION) +static esp_err_t esp_wolfssl_set_cipher_list(WOLFSSL_CTX *ctx) +{ + const char *defaultCipherList; + int ret; +#if defined(HAVE_AESGCM) && !defined(NO_DH) +#ifdef WOLFSSL_TLS13 + defaultCipherList = "DHE-PSK-AES128-GCM-SHA256:" + "TLS13-AES128-GCM-SHA256"; +#else + defaultCipherList = "DHE-PSK-AES128-GCM-SHA256"; +#endif +#elif defined(HAVE_NULL_CIPHER) + defaultCipherList = "PSK-NULL-SHA256"; +#else + defaultCipherList = "PSK-AES128-CBC-SHA256"; +#endif + ESP_LOGD(TAG, "cipher list is %s", defaultCipherList); + if ((ret = wolfSSL_CTX_set_cipher_list(ctx,defaultCipherList)) != WOLFSSL_SUCCESS) { + wolfSSL_CTX_free(ctx); + ESP_LOGE(TAG, "can't set cipher list, returned %02x", ret); + return ESP_FAIL; + } + return ESP_OK; +} + +/* initialize the mutex before app_main() when using PSK */ +static void __attribute__((constructor)) +espt_tls_wolfssl_init_conn_lock (void) +{ + if ((tls_conn_lock = xSemaphoreCreateMutex()) == NULL) { + ESP_EARLY_LOGE(TAG, "mutex for tls psk connection could not be created"); + } +} + +/* Some callback functions required by PSK */ +static inline unsigned int esp_wolfssl_psk_client_cb(WOLFSSL* ssl, const char* hint, + char* identity, unsigned int id_max_len, unsigned char* key, + unsigned int key_max_len) +{ + (void)key_max_len; + + /* see internal.h MAX_PSK_ID_LEN for PSK identity limit */ + memcpy(identity, psk_id_str, id_max_len); + for(int count = 0; count < psk_key_max_len; count ++) { + key[count] = psk_key_array[count]; + } + xSemaphoreGive(tls_conn_lock); + return psk_key_max_len; + /* return length of key in octets or 0 or for error */ +} +#endif /* CONFIG_ESP_TLS_PSK_VERIFICATION */ diff --git a/components/esp-tls/private_include/esp_tls_wolfssl.h b/components/esp-tls/private_include/esp_tls_wolfssl.h index 73cb9f2f..a04ad796 100644 --- a/components/esp-tls/private_include/esp_tls_wolfssl.h +++ b/components/esp-tls/private_include/esp_tls_wolfssl.h @@ -70,3 +70,17 @@ void esp_wolfssl_free_global_ca_store(void); * Callback function for Initializing the global ca store for TLS?SSL using wolfssl */ esp_err_t esp_wolfssl_init_global_ca_store(void); + +#ifdef CONFIG_ESP_TLS_SERVER + +/** + * Function to Create ESP-TLS Server session with wolfssl Stack + */ +int esp_wolfssl_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp_tls_t *tls); + +/* + * Delete Server Session + */ +void esp_wolfssl_server_session_delete(esp_tls_t *tls); + +#endif diff --git a/components/esp_http_client/lib/include/http_utils.h b/components/esp_http_client/lib/include/http_utils.h index d5784022..14c3b10c 100644 --- a/components/esp_http_client/lib/include/http_utils.h +++ b/components/esp_http_client/lib/include/http_utils.h @@ -16,7 +16,7 @@ #ifndef _HTTP_UTILS_H_ #define _HTTP_UTILS_H_ #include -#include "esp_transport_utils.h" + /** * @brief Assign new_str to *str pointer, and realloc *str if it not NULL * @@ -80,7 +80,9 @@ char *http_utils_join_string(const char *first_str, int len_first, const char *s int http_utils_str_starts_with(const char *str, const char *start); -#define HTTP_MEM_CHECK(TAG, a, action) ESP_TRANSPORT_MEM_CHECK(TAG, a, action) - +#define HTTP_MEM_CHECK(TAG, a, action) if (!(a)) { \ + ESP_LOGE(TAG,"%s:%d (%s): %s", __FILE__, __LINE__, __FUNCTION__, "Memory exhausted"); \ + action; \ + } #endif diff --git a/components/tcp_transport/include/esp_transport.h b/components/tcp_transport/include/esp_transport.h index 39e694f0..4841725d 100644 --- a/components/tcp_transport/include/esp_transport.h +++ b/components/tcp_transport/include/esp_transport.h @@ -133,7 +133,7 @@ esp_err_t esp_transport_set_default_port(esp_transport_handle_t t, int port); * @param t The transport handle * @param[in] host Hostname * @param[in] port Port - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - socket for will use by this transport @@ -147,7 +147,7 @@ int esp_transport_connect(esp_transport_handle_t t, const char *host, int port, * @param t The transport handle * @param[in] host Hostname * @param[in] port Port - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - socket for will use by this transport @@ -161,7 +161,7 @@ int esp_transport_connect_async(esp_transport_handle_t t, const char *host, int * @param t The transport handle * @param buffer The buffer * @param[in] len The length - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - Number of bytes was read @@ -173,7 +173,7 @@ int esp_transport_read(esp_transport_handle_t t, char *buffer, int len, int time * @brief Poll the transport until readable or timeout * * @param[in] t The transport handle - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - 0 Timeout @@ -188,7 +188,7 @@ int esp_transport_poll_read(esp_transport_handle_t t, int timeout_ms); * @param t The transport handle * @param buffer The buffer * @param[in] len The length - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - Number of bytes was written @@ -200,7 +200,7 @@ int esp_transport_write(esp_transport_handle_t t, const char *buffer, int len, i * @brief Poll the transport until writeable or timeout * * @param[in] t The transport handle - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates wait forever) * * @return * - 0 Timeout diff --git a/components/tcp_transport/include/esp_transport_ssl.h b/components/tcp_transport/include/esp_transport_ssl.h index a83e9388..9ce0b19c 100644 --- a/components/tcp_transport/include/esp_transport_ssl.h +++ b/components/tcp_transport/include/esp_transport_ssl.h @@ -92,6 +92,16 @@ void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const */ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len); +/** + * @brief Set SSL client key password if the key is password protected. The configured + * password is passed to the underlying TLS stack to decrypt the client key + * + * @param t ssl transport + * @param[in] password Pointer to the password + * @param[in] password_len Password length + */ +void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len); + /** * @brief Set SSL client key data for mutual authentication (as DER format). * Note that, this function stores the pointer to data, rather than making a copy. @@ -103,6 +113,16 @@ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char */ void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len); +/** + * @brief Set the list of supported application protocols to be used with ALPN. + * Note that, this function stores the pointer to data, rather than making a copy. + * So this data must remain valid until after the connection is cleaned up + * + * @param t ssl transport + * @param[in] alpn_porot The list of ALPN protocols, the last entry must be NULL + */ +void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos); + /** * @brief Skip validation of certificate's common name field * diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 0876480a..5e540579 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -14,6 +14,7 @@ extern "C" { #endif typedef enum ws_transport_opcodes { + WS_TRANSPORT_OPCODES_CONT = 0x00, WS_TRANSPORT_OPCODES_TEXT = 0x01, WS_TRANSPORT_OPCODES_BINARY = 0x02, WS_TRANSPORT_OPCODES_CLOSE = 0x08, @@ -50,6 +51,30 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path); */ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol); +/** + * @brief Set websocket user-agent header + * + * @param t websocket transport handle + * @param sub_protocol user-agent string + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent); + +/** + * @brief Set websocket additional headers + * + * @param t websocket transport handle + * @param sub_protocol additional header strings each terminated with \r\n + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers); + /** * @brief Sends websocket raw message with custom opcode and payload * @@ -63,7 +88,7 @@ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char * @param[in] opcode ws operation code * @param[in] buffer The buffer * @param[in] len The length - * @param[in] timeout_ms The timeout milliseconds + * @param[in] timeout_ms The timeout milliseconds (-1 indicates block forever) * * @return * - Number of bytes was written @@ -81,6 +106,16 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o */ ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t); +/** + * @brief Returns payload length of the last received data + * + * @param t websocket transport handle + * + * @return + * - Number of bytes in the payload + */ +int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t); + #ifdef __cplusplus } diff --git a/components/tcp_transport/include/esp_transport_utils.h b/components/tcp_transport/private_include/esp_transport_utils.h similarity index 65% rename from components/tcp_transport/include/esp_transport_utils.h rename to components/tcp_transport/private_include/esp_transport_utils.h index 6a9d1d02..dbc91b11 100644 --- a/components/tcp_transport/include/esp_transport_utils.h +++ b/components/tcp_transport/private_include/esp_transport_utils.h @@ -30,12 +30,17 @@ extern "C" { } /** - * @brief Convert milliseconds to timeval struct + * @brief Convert milliseconds to timeval struct for valid timeouts, otherwise + * (if "wait forever" requested by timeout_ms=-1) timeval structure is not updated and NULL returned * - * @param[in] timeout_ms The timeout milliseconds - * @param[out] tv Timeval struct + * @param[in] timeout_ms The timeout value in milliseconds or -1 to waiting forever + * @param[out] tv Pointer to timeval struct + * + * @return + * - NULL if timeout_ms=-1 (wait forever) + * - pointer to the updated timeval structure (provided as "tv" argument) with recalculated timeout value */ -void esp_transport_utils_ms_to_timeval(int timeout_ms, struct timeval *tv); +struct timeval* esp_transport_utils_ms_to_timeval(int timeout_ms, struct timeval *tv); #ifdef __cplusplus diff --git a/components/tcp_transport/test/CMakeLists.txt b/components/tcp_transport/test/CMakeLists.txt index 88504a27..89846a75 100644 --- a/components/tcp_transport/test/CMakeLists.txt +++ b/components/tcp_transport/test/CMakeLists.txt @@ -1,5 +1,3 @@ -set(COMPONENT_SRCDIRS ".") -set(COMPONENT_PRIV_INCLUDEDIRS "../private_include" ".") -set(COMPONENT_PRIV_REQUIRES unity test_utils tcp_transport) - -register_component() \ No newline at end of file +idf_component_register(SRC_DIRS "." + PRIV_INCLUDE_DIRS "../private_include" "." + PRIV_REQUIRES unity test_utils tcp_transport) \ No newline at end of file diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index b92c2115..4bca105d 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -71,7 +71,7 @@ static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int ssl->cfg.timeout_ms = timeout_ms; ssl->ssl_initialized = true; ssl->tls = esp_tls_init(); - if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) < 0) { + if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) { ESP_LOGE(TAG, "Failed to open a new connection"); esp_transport_set_errors(t, ssl->tls->error_handle); esp_tls_conn_delete(ssl->tls); @@ -86,16 +86,15 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_ssl_t *ssl = esp_transport_get_context_data(t); int ret = -1; + struct timeval timeout; fd_set readset; fd_set errset; FD_ZERO(&readset); FD_ZERO(&errset); FD_SET(ssl->tls->sockfd, &readset); FD_SET(ssl->tls->sockfd, &errset); - struct timeval timeout; - esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, &timeout); + ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); @@ -110,15 +109,14 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) { transport_ssl_t *ssl = esp_transport_get_context_data(t); int ret = -1; + struct timeval timeout; fd_set writeset; fd_set errset; FD_ZERO(&writeset); FD_ZERO(&errset); FD_SET(ssl->tls->sockfd, &writeset); FD_SET(ssl->tls->sockfd, &errset); - struct timeval timeout; - esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, &timeout); + ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); @@ -247,6 +245,15 @@ void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char } } +void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.clientkey_password = (void *)password; + ssl->cfg.clientkey_password_len = password_len; + } +} + void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len) { transport_ssl_t *ssl = esp_transport_get_context_data(t); @@ -256,6 +263,14 @@ void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const c } } +void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos) +{ + transport_ssl_t *ssl = esp_transport_get_context_data(t); + if (t && ssl) { + ssl->cfg.alpn_protos = alpn_protos; + } +} + void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t) { transport_ssl_t *ssl = esp_transport_get_context_data(t); diff --git a/components/tcp_transport/transport_tcp.c b/components/tcp_transport/transport_tcp.c index 3fba399a..5bfb99dd 100644 --- a/components/tcp_transport/transport_tcp.c +++ b/components/tcp_transport/transport_tcp.c @@ -52,7 +52,7 @@ static int resolve_dns(const char *host, struct sockaddr_in *ip) { static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { struct sockaddr_in remote_ip; - struct timeval tv; + struct timeval tv = { 0 }; transport_tcp_t *tcp = esp_transport_get_context_data(t); bzero(&remote_ip, sizeof(struct sockaddr_in)); @@ -74,7 +74,7 @@ static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int remote_ip.sin_family = AF_INET; remote_ip.sin_port = htons(port); - esp_transport_utils_ms_to_timeval(timeout_ms, &tv); + esp_transport_utils_ms_to_timeval(timeout_ms, &tv); // if timeout=-1, tv is unchanged, 0, i.e. waits forever setsockopt(tcp->sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); setsockopt(tcp->sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); @@ -117,15 +117,15 @@ static int tcp_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_tcp_t *tcp = esp_transport_get_context_data(t); int ret = -1; + struct timeval timeout; fd_set readset; fd_set errset; FD_ZERO(&readset); FD_ZERO(&errset); FD_SET(tcp->sock, &readset); FD_SET(tcp->sock, &errset); - struct timeval timeout; - esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - ret = select(tcp->sock + 1, &readset, NULL, &errset, &timeout); + + ret = select(tcp->sock + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); if (ret > 0 && FD_ISSET(tcp->sock, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); @@ -140,15 +140,15 @@ static int tcp_poll_write(esp_transport_handle_t t, int timeout_ms) { transport_tcp_t *tcp = esp_transport_get_context_data(t); int ret = -1; + struct timeval timeout; fd_set writeset; fd_set errset; FD_ZERO(&writeset); FD_ZERO(&errset); FD_SET(tcp->sock, &writeset); FD_SET(tcp->sock, &errset); - struct timeval timeout; - esp_transport_utils_ms_to_timeval(timeout_ms, &timeout); - ret = select(tcp->sock + 1, NULL, &writeset, &errset, &timeout); + + ret = select(tcp->sock + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); if (ret > 0 && FD_ISSET(tcp->sock, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); diff --git a/components/tcp_transport/transport_utils.c b/components/tcp_transport/transport_utils.c index 9ef56dc2..5e0a0216 100644 --- a/components/tcp_transport/transport_utils.c +++ b/components/tcp_transport/transport_utils.c @@ -6,8 +6,12 @@ #include "esp_transport_utils.h" -void esp_transport_utils_ms_to_timeval(int timeout_ms, struct timeval *tv) +struct timeval* esp_transport_utils_ms_to_timeval(int timeout_ms, struct timeval *tv) { + if (timeout_ms == -1) { + return NULL; + } tv->tv_sec = timeout_ms / 1000; tv->tv_usec = (timeout_ms - (tv->tv_sec * 1000)) * 1000; + return tv; } \ No newline at end of file diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 637d4d61..3cff4198 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -2,7 +2,6 @@ #include #include #include - #include "esp_log.h" #include "esp_transport.h" #include "esp_transport_tcp.h" @@ -15,11 +14,13 @@ static const char *TAG = "TRANSPORT_WS"; #define DEFAULT_WS_BUFFER (1024) #define WS_FIN 0x80 +#define WS_OPCODE_CONT 0x00 #define WS_OPCODE_TEXT 0x01 #define WS_OPCODE_BINARY 0x02 #define WS_OPCODE_CLOSE 0x08 #define WS_OPCODE_PING 0x09 #define WS_OPCODE_PONG 0x0a + // Second byte #define WS_MASK 0x80 #define WS_SIZE16 126 @@ -27,11 +28,21 @@ static const char *TAG = "TRANSPORT_WS"; #define MAX_WEBSOCKET_HEADER_SIZE 16 #define WS_RESPONSE_OK 101 + +typedef struct { + uint8_t opcode; + char mask_key[4]; /*!< Mask key for this payload */ + int payload_len; /*!< Total length of the payload */ + int bytes_remaining; /*!< Bytes left to read of the payload */ +} ws_transport_frame_state_t; + typedef struct { char *path; char *buffer; char *sub_protocol; - uint8_t read_opcode; + char *user_agent; + char *headers; + ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; } transport_ws_t; @@ -43,6 +54,11 @@ static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode) static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); + + /* Reading parts of a frame directly will disrupt the WS internal frame state, + reset bytes_remaining to prepare for reading a new frame */ + ws->frame_state.bytes_remaining = 0; + return ws->parent; } @@ -96,24 +112,27 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int // Size of base64 coded string is equal '((input_size * 4) / 3) + (input_size / 96) + 6' including Z-term unsigned char client_key[28] = {0}; + const char *user_agent_ptr = (ws->user_agent)?(ws->user_agent):"ESP32 Websocket Client"; + size_t outlen = 0; mbedtls_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key)); int len = snprintf(ws->buffer, DEFAULT_WS_BUFFER, "GET %s HTTP/1.1\r\n" "Connection: Upgrade\r\n" "Host: %s:%d\r\n" + "User-Agent: %s\r\n" "Upgrade: websocket\r\n" "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Key: %s\r\n" - "User-Agent: ESP32 Websocket Client\r\n", + "Sec-WebSocket-Key: %s\r\n", ws->path, - host, port, + host, port, user_agent_ptr, client_key); if (len <= 0 || len >= DEFAULT_WS_BUFFER) { ESP_LOGE(TAG, "Error in request generation, %d", len); return -1; } if (ws->sub_protocol) { + ESP_LOGD(TAG, "sub_protocol: %s", ws->sub_protocol); int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "Sec-WebSocket-Protocol: %s\r\n", ws->sub_protocol); len += r; if (r <= 0 || len >= DEFAULT_WS_BUFFER) { @@ -122,6 +141,16 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return -1; } } + if (ws->headers) { + ESP_LOGD(TAG, "headers: %s", ws->headers); + int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "%s", ws->headers); + len += r; + if (r <= 0 || len >= DEFAULT_WS_BUFFER) { + ESP_LOGE(TAG, "Error in request generation" + "(strncpy of headers returned %d, desired request len: %d, buffer size: %d", r, len, DEFAULT_WS_BUFFER); + return -1; + } + } int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "\r\n"); len += r; if (r <= 0 || len >= DEFAULT_WS_BUFFER) { @@ -233,7 +262,7 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const for (i = 0; i < len; ++i) { buffer[i] = (buffer[i] ^ mask[i % 4]); } - } + } return ret; } @@ -260,12 +289,46 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms); } -static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) + +static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + + int bytes_to_read; + int rlen = 0; + + if (ws->frame_state.bytes_remaining > len) { + ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", ws->frame_state.bytes_remaining, len); + bytes_to_read = len; + + } else { + bytes_to_read = ws->frame_state.bytes_remaining; + } + + // Receive and process payload + if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; + } + ws->frame_state.bytes_remaining -= rlen; + + if (ws->frame_state.mask_key) { + for (int i = 0; i < bytes_to_read; i++) { + buffer[i] = (buffer[i] ^ ws->frame_state.mask_key[i % 4]); + } + } + return rlen; +} + + +/* Read and parse the WS header, determine length of payload */ +static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); int payload_len; + char ws_header[MAX_WEBSOCKET_HEADER_SIZE]; - char *data_ptr = ws_header, mask, *mask_key = NULL; + char *data_ptr = ws_header, mask; int rlen; int poll_read; if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) { @@ -274,16 +337,17 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ // Receive and process header first (based on header size) int header = 2; + int mask_len = 4; if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } - ws->read_opcode = (*data_ptr & 0x0F); + ws->frame_state.opcode = (*data_ptr & 0x0F); data_ptr ++; mask = ((*data_ptr >> 7) & 0x01); payload_len = (*data_ptr & 0x7F); data_ptr++; - ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->read_opcode, mask, payload_len); + ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { @@ -307,27 +371,48 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ } } - if (payload_len > len) { - ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", payload_len, len); - payload_len = len; - } - - // Then receive and process payload - if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) { - ESP_LOGE(TAG, "Error read data"); - return rlen; - } - if (mask) { - mask_key = buffer; - data_ptr = buffer + 4; - for (int i = 0; i < payload_len; i++) { - buffer[i] = (data_ptr[i] ^ mask_key[i % 4]); + // Read and store mask + if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; } + memcpy(ws->frame_state.mask_key, buffer, mask_len); + } else { + memset(ws->frame_state.mask_key, 0, mask_len); } + + ws->frame_state.payload_len = payload_len; + ws->frame_state.bytes_remaining = payload_len; + return payload_len; } +static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +{ + int rlen = 0; + transport_ws_t *ws = esp_transport_get_context_data(t); + + // If message exceeds buffer len then subsequent reads will skip reading header and read whatever is left of the payload + if (ws->frame_state.bytes_remaining <= 0) { + if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) <= 0) { + // If something when wrong then we prepare for reading a new header + ws->frame_state.bytes_remaining = 0; + return rlen; + } + } + if (ws->frame_state.payload_len) { + if ( (rlen = ws_read_payload(t, buffer, len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error reading payload data"); + ws->frame_state.bytes_remaining = 0; + return rlen; + } + } + + return rlen; +} + + static int ws_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); @@ -352,6 +437,8 @@ static esp_err_t ws_destroy(esp_transport_handle_t t) free(ws->buffer); free(ws->path); free(ws->sub_protocol); + free(ws->user_agent); + free(ws->headers); free(ws); return 0; } @@ -409,8 +496,56 @@ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char return ESP_OK; } +esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->user_agent) { + free(ws->user_agent); + } + if (user_agent == NULL) { + ws->user_agent = NULL; + return ESP_OK; + } + ws->user_agent = strdup(user_agent); + if (ws->user_agent == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + +esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->headers) { + free(ws->headers); + } + if (headers == NULL) { + ws->headers = NULL; + return ESP_OK; + } + ws->headers = strdup(headers); + if (ws->headers == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); - return ws->read_opcode; + return ws->frame_state.opcode; } + +int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + return ws->frame_state.payload_len; +} + +