mirror of
https://github.com/78/xiaozhi-esp32.git
synced 2025-08-06 10:19:44 +08:00
add settings
This commit is contained in:
@ -4,7 +4,7 @@
|
|||||||
# CMakeLists in this exact order for cmake to work correctly
|
# CMakeLists in this exact order for cmake to work correctly
|
||||||
cmake_minimum_required(VERSION 3.16)
|
cmake_minimum_required(VERSION 3.16)
|
||||||
|
|
||||||
set(PROJECT_VER "0.8.0")
|
set(PROJECT_VER "0.8.1")
|
||||||
|
|
||||||
include($ENV{IDF_PATH}/tools/cmake/project.cmake)
|
include($ENV{IDF_PATH}/tools/cmake/project.cmake)
|
||||||
project(xiaozhi)
|
project(xiaozhi)
|
||||||
|
@ -16,6 +16,7 @@ set(SOURCES "audio_codec.cc"
|
|||||||
"button.cc"
|
"button.cc"
|
||||||
"led.cc"
|
"led.cc"
|
||||||
"ota.cc"
|
"ota.cc"
|
||||||
|
"settings.cc"
|
||||||
"main.cc"
|
"main.cc"
|
||||||
)
|
)
|
||||||
set(INCLUDE_DIRS ".")
|
set(INCLUDE_DIRS ".")
|
||||||
|
@ -37,9 +37,6 @@ Application::~Application() {
|
|||||||
if (audio_encode_task_stack_ != nullptr) {
|
if (audio_encode_task_stack_ != nullptr) {
|
||||||
heap_caps_free(audio_encode_task_stack_);
|
heap_caps_free(audio_encode_task_stack_);
|
||||||
}
|
}
|
||||||
if (main_loop_task_stack_ != nullptr) {
|
|
||||||
heap_caps_free(main_loop_task_stack_);
|
|
||||||
}
|
|
||||||
|
|
||||||
vEventGroupDelete(event_group_);
|
vEventGroupDelete(event_group_);
|
||||||
}
|
}
|
||||||
@ -47,20 +44,34 @@ Application::~Application() {
|
|||||||
void Application::CheckNewVersion() {
|
void Application::CheckNewVersion() {
|
||||||
// Check if there is a new firmware version available
|
// Check if there is a new firmware version available
|
||||||
ota_.SetPostData(Board::GetInstance().GetJson());
|
ota_.SetPostData(Board::GetInstance().GetJson());
|
||||||
ota_.CheckVersion();
|
|
||||||
if (ota_.HasNewVersion()) {
|
while (true) {
|
||||||
SetChatState(kChatStateUpgrading);
|
if (ota_.CheckVersion()) {
|
||||||
ota_.StartUpgrade([](int progress, size_t speed) {
|
if (ota_.HasNewVersion()) {
|
||||||
char buffer[64];
|
// Wait for the chat state to be idle
|
||||||
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
|
while (chat_state_ != kChatStateIdle) {
|
||||||
auto display = Board::GetInstance().GetDisplay();
|
vTaskDelay(100);
|
||||||
display->SetText(buffer);
|
}
|
||||||
});
|
|
||||||
// If upgrade success, the device will reboot and never reach here
|
SetChatState(kChatStateUpgrading);
|
||||||
ESP_LOGI(TAG, "Firmware upgrade failed...");
|
ota_.StartUpgrade([](int progress, size_t speed) {
|
||||||
SetChatState(kChatStateIdle);
|
char buffer[64];
|
||||||
} else {
|
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
|
||||||
ota_.MarkCurrentVersionValid();
|
auto display = Board::GetInstance().GetDisplay();
|
||||||
|
display->SetText(buffer);
|
||||||
|
});
|
||||||
|
|
||||||
|
// If upgrade success, the device will reboot and never reach here
|
||||||
|
ESP_LOGI(TAG, "Firmware upgrade failed...");
|
||||||
|
SetChatState(kChatStateIdle);
|
||||||
|
} else {
|
||||||
|
ota_.MarkCurrentVersionValid();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check again in 60 seconds
|
||||||
|
vTaskDelay(pdMS_TO_TICKS(60000));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,24 +202,18 @@ void Application::Start() {
|
|||||||
/* Wait for the network to be ready */
|
/* Wait for the network to be ready */
|
||||||
board.StartNetwork();
|
board.StartNetwork();
|
||||||
|
|
||||||
const size_t main_loop_stack_size = 4096 * 8;
|
xTaskCreate([](void* arg) {
|
||||||
main_loop_task_stack_ = (StackType_t*)heap_caps_malloc(main_loop_stack_size, MALLOC_CAP_SPIRAM);
|
|
||||||
xTaskCreateStatic([](void* arg) {
|
|
||||||
Application* app = (Application*)arg;
|
Application* app = (Application*)arg;
|
||||||
app->MainLoop();
|
app->MainLoop();
|
||||||
vTaskDelete(NULL);
|
vTaskDelete(NULL);
|
||||||
}, "main_loop", main_loop_stack_size, this, 1, main_loop_task_stack_, &main_loop_task_buffer_);
|
}, "main_loop", 4096 * 2, this, 1, nullptr);
|
||||||
|
|
||||||
// Check for new firmware version or get the MQTT broker address
|
// Check for new firmware version or get the MQTT broker address
|
||||||
while (true) {
|
xTaskCreate([](void* arg) {
|
||||||
CheckNewVersion();
|
Application* app = (Application*)arg;
|
||||||
|
app->CheckNewVersion();
|
||||||
if (ota_.HasMqttConfig()) {
|
vTaskDelete(NULL);
|
||||||
break;
|
}, "check_new_version", 4096 * 2, this, 1, nullptr);
|
||||||
}
|
|
||||||
Alert("Error", "Missing MQTT config");
|
|
||||||
vTaskDelay(pdMS_TO_TICKS(10000));
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef CONFIG_USE_AFE_SR
|
#ifdef CONFIG_USE_AFE_SR
|
||||||
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
|
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
|
||||||
@ -264,12 +269,19 @@ void Application::Start() {
|
|||||||
|
|
||||||
// Initialize the protocol
|
// Initialize the protocol
|
||||||
display->SetText("Starting\nProtocol...");
|
display->SetText("Starting\nProtocol...");
|
||||||
protocol_ = new MqttProtocol(ota_.GetMqttConfig());
|
protocol_ = new MqttProtocol();
|
||||||
protocol_->OnIncomingAudio([this](const std::string& data) {
|
protocol_->OnIncomingAudio([this](const std::string& data) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
audio_decode_queue_.emplace_back(std::move(data));
|
audio_decode_queue_.emplace_back(std::move(data));
|
||||||
cv_.notify_all();
|
cv_.notify_all();
|
||||||
});
|
});
|
||||||
|
protocol_->OnAudioChannelOpened([this, codec]() {
|
||||||
|
if (protocol_->GetServerSampleRate() != codec->output_sample_rate()) {
|
||||||
|
ESP_LOGW(TAG, "服务器的音频采样率 %d 与设备输出的采样率 %d 不一致,重采样后可能会失真",
|
||||||
|
protocol_->GetServerSampleRate(), codec->output_sample_rate());
|
||||||
|
}
|
||||||
|
SetDecodeSampleRate(protocol_->GetServerSampleRate());
|
||||||
|
});
|
||||||
protocol_->OnAudioChannelClosed([this]() {
|
protocol_->OnAudioChannelClosed([this]() {
|
||||||
Schedule([this]() {
|
Schedule([this]() {
|
||||||
SetChatState(kChatStateIdle);
|
SetChatState(kChatStateIdle);
|
||||||
@ -289,7 +301,9 @@ void Application::Start() {
|
|||||||
Schedule([this]() {
|
Schedule([this]() {
|
||||||
auto codec = Board::GetInstance().GetAudioCodec();
|
auto codec = Board::GetInstance().GetAudioCodec();
|
||||||
codec->WaitForOutputDone();
|
codec->WaitForOutputDone();
|
||||||
SetChatState(kChatStateListening);
|
if (chat_state_ == kChatStateSpeaking) {
|
||||||
|
SetChatState(kChatStateListening);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
|
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
|
||||||
auto text = cJSON_GetObjectItem(root, "text");
|
auto text = cJSON_GetObjectItem(root, "text");
|
||||||
@ -307,15 +321,6 @@ void Application::Start() {
|
|||||||
if (emotion != NULL) {
|
if (emotion != NULL) {
|
||||||
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
|
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
|
||||||
}
|
}
|
||||||
} else if (strcmp(type->valuestring, "hello") == 0) {
|
|
||||||
// Get sample rate from hello message
|
|
||||||
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
|
|
||||||
if (audio_params != NULL) {
|
|
||||||
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
|
|
||||||
if (sample_rate != NULL) {
|
|
||||||
SetDecodeSampleRate(sample_rate->valueint);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -351,8 +356,7 @@ void Application::MainLoop() {
|
|||||||
|
|
||||||
void Application::AbortSpeaking() {
|
void Application::AbortSpeaking() {
|
||||||
ESP_LOGI(TAG, "Abort speaking");
|
ESP_LOGI(TAG, "Abort speaking");
|
||||||
std::string json = "{\"type\":\"abort\"}";
|
protocol_->SendAbort();
|
||||||
protocol_->SendText(json);
|
|
||||||
|
|
||||||
skip_to_end_ = true;
|
skip_to_end_ = true;
|
||||||
auto codec = Board::GetInstance().GetAudioCodec();
|
auto codec = Board::GetInstance().GetAudioCodec();
|
||||||
@ -420,10 +424,7 @@ void Application::SetChatState(ChatState state) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string json = "{\"type\":\"state\",\"state\":\"";
|
protocol_->SendState(state_str[chat_state_]);
|
||||||
json += state_str[chat_state_];
|
|
||||||
json += "\"}";
|
|
||||||
protocol_->SendText(json);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Application::AudioEncodeTask() {
|
void Application::AudioEncodeTask() {
|
||||||
@ -455,7 +456,7 @@ void Application::AudioEncodeTask() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000;
|
int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000;
|
||||||
std::vector<int16_t> pcm(frame_size);
|
std::vector<int16_t> pcm(frame_size);
|
||||||
|
|
||||||
int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);
|
int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);
|
||||||
|
@ -39,6 +39,8 @@ enum ChatState {
|
|||||||
kChatStateUpgrading
|
kChatStateUpgrading
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#define OPUS_FRAME_DURATION_MS 60
|
||||||
|
|
||||||
class Application {
|
class Application {
|
||||||
public:
|
public:
|
||||||
static Application& GetInstance() {
|
static Application& GetInstance() {
|
||||||
@ -84,16 +86,11 @@ private:
|
|||||||
OpusEncoder opus_encoder_;
|
OpusEncoder opus_encoder_;
|
||||||
OpusDecoder* opus_decoder_ = nullptr;
|
OpusDecoder* opus_decoder_ = nullptr;
|
||||||
|
|
||||||
int opus_duration_ms_ = 60;
|
|
||||||
int opus_decode_sample_rate_ = -1;
|
int opus_decode_sample_rate_ = -1;
|
||||||
OpusResampler input_resampler_;
|
OpusResampler input_resampler_;
|
||||||
OpusResampler reference_resampler_;
|
OpusResampler reference_resampler_;
|
||||||
OpusResampler output_resampler_;
|
OpusResampler output_resampler_;
|
||||||
|
|
||||||
TaskHandle_t main_loop_task_ = nullptr;
|
|
||||||
StaticTask_t main_loop_task_buffer_;
|
|
||||||
StackType_t* main_loop_task_stack_ = nullptr;
|
|
||||||
|
|
||||||
void MainLoop();
|
void MainLoop();
|
||||||
void SetDecodeSampleRate(int sample_rate);
|
void SetDecodeSampleRate(int sample_rate);
|
||||||
void CheckNewVersion();
|
void CheckNewVersion();
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#include "audio_codec.h"
|
#include "audio_codec.h"
|
||||||
#include "board.h"
|
#include "board.h"
|
||||||
|
#include "settings.h"
|
||||||
|
|
||||||
#include <esp_log.h>
|
#include <esp_log.h>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -40,6 +41,9 @@ IRAM_ATTR bool AudioCodec::on_sent(i2s_chan_handle_t handle, i2s_event_data_t *e
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AudioCodec::Start() {
|
void AudioCodec::Start() {
|
||||||
|
Settings settings("audio", false);
|
||||||
|
output_volume_ = settings.GetInt("output_volume", output_volume_);
|
||||||
|
|
||||||
// 注册音频输出回调
|
// 注册音频输出回调
|
||||||
i2s_event_callbacks_t callbacks = {};
|
i2s_event_callbacks_t callbacks = {};
|
||||||
callbacks.on_sent = on_sent;
|
callbacks.on_sent = on_sent;
|
||||||
@ -124,6 +128,9 @@ void AudioCodec::ClearOutputQueue() {
|
|||||||
void AudioCodec::SetOutputVolume(int volume) {
|
void AudioCodec::SetOutputVolume(int volume) {
|
||||||
output_volume_ = volume;
|
output_volume_ = volume;
|
||||||
ESP_LOGI(TAG, "Set output volume to %d", output_volume_);
|
ESP_LOGI(TAG, "Set output volume to %d", output_volume_);
|
||||||
|
|
||||||
|
Settings settings("audio", true);
|
||||||
|
settings.SetInt("output_volume", output_volume_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AudioCodec::EnableInput(bool enable) {
|
void AudioCodec::EnableInput(bool enable) {
|
||||||
|
@ -18,6 +18,7 @@ extern "C" void app_main(void)
|
|||||||
// Initialize NVS flash for WiFi configuration
|
// Initialize NVS flash for WiFi configuration
|
||||||
esp_err_t ret = nvs_flash_init();
|
esp_err_t ret = nvs_flash_init();
|
||||||
if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
|
if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
|
||||||
|
ESP_LOGW(TAG, "Erasing NVS flash to fix corruption");
|
||||||
ESP_ERROR_CHECK(nvs_flash_erase());
|
ESP_ERROR_CHECK(nvs_flash_erase());
|
||||||
ret = nvs_flash_init();
|
ret = nvs_flash_init();
|
||||||
}
|
}
|
||||||
|
20
main/ota.cc
20
main/ota.cc
@ -1,6 +1,7 @@
|
|||||||
#include "ota.h"
|
#include "ota.h"
|
||||||
#include "system_info.h"
|
#include "system_info.h"
|
||||||
#include "board.h"
|
#include "board.h"
|
||||||
|
#include "settings.h"
|
||||||
|
|
||||||
#include <cJSON.h>
|
#include <cJSON.h>
|
||||||
#include <esp_log.h>
|
#include <esp_log.h>
|
||||||
@ -34,13 +35,13 @@ void Ota::SetPostData(const std::string& post_data) {
|
|||||||
post_data_ = post_data;
|
post_data_ = post_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ota::CheckVersion() {
|
bool Ota::CheckVersion() {
|
||||||
std::string current_version = esp_app_get_description()->version;
|
std::string current_version = esp_app_get_description()->version;
|
||||||
ESP_LOGI(TAG, "Current version: %s", current_version.c_str());
|
ESP_LOGI(TAG, "Current version: %s", current_version.c_str());
|
||||||
|
|
||||||
if (check_version_url_.length() < 10) {
|
if (check_version_url_.length() < 10) {
|
||||||
ESP_LOGE(TAG, "Check version URL is not properly set");
|
ESP_LOGE(TAG, "Check version URL is not properly set");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto http = Board::GetInstance().CreateHttp();
|
auto http = Board::GetInstance().CreateHttp();
|
||||||
@ -67,16 +68,18 @@ void Ota::CheckVersion() {
|
|||||||
cJSON *root = cJSON_Parse(response.c_str());
|
cJSON *root = cJSON_Parse(response.c_str());
|
||||||
if (root == NULL) {
|
if (root == NULL) {
|
||||||
ESP_LOGE(TAG, "Failed to parse JSON response");
|
ESP_LOGE(TAG, "Failed to parse JSON response");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
|
cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
|
||||||
if (mqtt != NULL) {
|
if (mqtt != NULL) {
|
||||||
|
Settings settings("mqtt", true);
|
||||||
cJSON *item = NULL;
|
cJSON *item = NULL;
|
||||||
cJSON_ArrayForEach(item, mqtt) {
|
cJSON_ArrayForEach(item, mqtt) {
|
||||||
if (item->type == cJSON_String) {
|
if (item->type == cJSON_String) {
|
||||||
mqtt_config_[item->string] = item->valuestring;
|
if (settings.GetString(item->string) != item->valuestring) {
|
||||||
ESP_LOGI(TAG, "MQTT config: %s = %s", item->string, item->valuestring);
|
settings.SetString(item->string, item->valuestring);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
has_mqtt_config_ = true;
|
has_mqtt_config_ = true;
|
||||||
@ -86,19 +89,19 @@ void Ota::CheckVersion() {
|
|||||||
if (firmware == NULL) {
|
if (firmware == NULL) {
|
||||||
ESP_LOGE(TAG, "Failed to get firmware object");
|
ESP_LOGE(TAG, "Failed to get firmware object");
|
||||||
cJSON_Delete(root);
|
cJSON_Delete(root);
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
cJSON *version = cJSON_GetObjectItem(firmware, "version");
|
cJSON *version = cJSON_GetObjectItem(firmware, "version");
|
||||||
if (version == NULL) {
|
if (version == NULL) {
|
||||||
ESP_LOGE(TAG, "Failed to get version object");
|
ESP_LOGE(TAG, "Failed to get version object");
|
||||||
cJSON_Delete(root);
|
cJSON_Delete(root);
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
cJSON *url = cJSON_GetObjectItem(firmware, "url");
|
cJSON *url = cJSON_GetObjectItem(firmware, "url");
|
||||||
if (url == NULL) {
|
if (url == NULL) {
|
||||||
ESP_LOGE(TAG, "Failed to get url object");
|
ESP_LOGE(TAG, "Failed to get url object");
|
||||||
cJSON_Delete(root);
|
cJSON_Delete(root);
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
firmware_version_ = version->valuestring;
|
firmware_version_ = version->valuestring;
|
||||||
@ -112,6 +115,7 @@ void Ota::CheckVersion() {
|
|||||||
} else {
|
} else {
|
||||||
ESP_LOGI(TAG, "Current is the latest version");
|
ESP_LOGI(TAG, "Current is the latest version");
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ota::MarkCurrentVersionValid() {
|
void Ota::MarkCurrentVersionValid() {
|
||||||
|
@ -13,14 +13,12 @@ public:
|
|||||||
void SetCheckVersionUrl(std::string check_version_url);
|
void SetCheckVersionUrl(std::string check_version_url);
|
||||||
void SetHeader(const std::string& key, const std::string& value);
|
void SetHeader(const std::string& key, const std::string& value);
|
||||||
void SetPostData(const std::string& post_data);
|
void SetPostData(const std::string& post_data);
|
||||||
void CheckVersion();
|
bool CheckVersion();
|
||||||
bool HasNewVersion() { return has_new_version_; }
|
bool HasNewVersion() { return has_new_version_; }
|
||||||
bool HasMqttConfig() { return has_mqtt_config_; }
|
bool HasMqttConfig() { return has_mqtt_config_; }
|
||||||
void StartUpgrade(std::function<void(int progress, size_t speed)> callback);
|
void StartUpgrade(std::function<void(int progress, size_t speed)> callback);
|
||||||
void MarkCurrentVersionValid();
|
void MarkCurrentVersionValid();
|
||||||
|
|
||||||
std::map<std::string, std::string>& GetMqttConfig() { return mqtt_config_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string check_version_url_;
|
std::string check_version_url_;
|
||||||
bool has_new_version_ = false;
|
bool has_new_version_ = false;
|
||||||
@ -29,7 +27,6 @@ private:
|
|||||||
std::string firmware_url_;
|
std::string firmware_url_;
|
||||||
std::string post_data_;
|
std::string post_data_;
|
||||||
std::map<std::string, std::string> headers_;
|
std::map<std::string, std::string> headers_;
|
||||||
std::map<std::string, std::string> mqtt_config_;
|
|
||||||
|
|
||||||
void Upgrade(const std::string& firmware_url);
|
void Upgrade(const std::string& firmware_url);
|
||||||
std::function<void(int progress, size_t speed)> upgrade_callback_;
|
std::function<void(int progress, size_t speed)> upgrade_callback_;
|
||||||
|
@ -13,10 +13,14 @@ public:
|
|||||||
virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0;
|
virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0;
|
||||||
virtual void SendAudio(const std::string& data) = 0;
|
virtual void SendAudio(const std::string& data) = 0;
|
||||||
virtual void SendText(const std::string& text) = 0;
|
virtual void SendText(const std::string& text) = 0;
|
||||||
|
virtual void SendState(const std::string& state) = 0;
|
||||||
|
virtual void SendAbort() = 0;
|
||||||
virtual bool OpenAudioChannel() = 0;
|
virtual bool OpenAudioChannel() = 0;
|
||||||
virtual void CloseAudioChannel() = 0;
|
virtual void CloseAudioChannel() = 0;
|
||||||
virtual void OnAudioChannelOpened(std::function<void()> callback) = 0;
|
virtual void OnAudioChannelOpened(std::function<void()> callback) = 0;
|
||||||
virtual void OnAudioChannelClosed(std::function<void()> callback) = 0;
|
virtual void OnAudioChannelClosed(std::function<void()> callback) = 0;
|
||||||
|
virtual bool IsAudioChannelOpened() const = 0;
|
||||||
|
virtual int GetServerSampleRate() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // PROTOCOL_H
|
#endif // PROTOCOL_H
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
#include "mqtt_protocol.h"
|
#include "mqtt_protocol.h"
|
||||||
#include "board.h"
|
#include "board.h"
|
||||||
|
#include "application.h"
|
||||||
|
#include "settings.h"
|
||||||
|
|
||||||
#include <esp_log.h>
|
#include <esp_log.h>
|
||||||
#include <ml307_mqtt.h>
|
#include <ml307_mqtt.h>
|
||||||
@ -9,16 +11,9 @@
|
|||||||
|
|
||||||
#define TAG "MQTT"
|
#define TAG "MQTT"
|
||||||
|
|
||||||
MqttProtocol::MqttProtocol(std::map<std::string, std::string>& config) {
|
MqttProtocol::MqttProtocol() {
|
||||||
event_group_handle_ = xEventGroupCreate();
|
event_group_handle_ = xEventGroupCreate();
|
||||||
|
|
||||||
endpoint_ = config["endpoint"];
|
|
||||||
client_id_ = config["client_id"];
|
|
||||||
username_ = config["username"];
|
|
||||||
password_ = config["password"];
|
|
||||||
subscribe_topic_ = config["subscribe_topic"];
|
|
||||||
publish_topic_ = config["publish_topic"];
|
|
||||||
|
|
||||||
StartMqttClient();
|
StartMqttClient();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,6 +34,19 @@ bool MqttProtocol::StartMqttClient() {
|
|||||||
delete mqtt_;
|
delete mqtt_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Settings settings("mqtt", false);
|
||||||
|
endpoint_ = settings.GetString("endpoint");
|
||||||
|
client_id_ = settings.GetString("client_id");
|
||||||
|
username_ = settings.GetString("username");
|
||||||
|
password_ = settings.GetString("password");
|
||||||
|
subscribe_topic_ = settings.GetString("subscribe_topic");
|
||||||
|
publish_topic_ = settings.GetString("publish_topic");
|
||||||
|
|
||||||
|
if (endpoint_.empty()) {
|
||||||
|
ESP_LOGE(TAG, "MQTT endpoint is not specified");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
mqtt_ = Board::GetInstance().CreateMqtt();
|
mqtt_ = Board::GetInstance().CreateMqtt();
|
||||||
mqtt_->SetKeepAlive(90);
|
mqtt_->SetKeepAlive(90);
|
||||||
|
|
||||||
@ -58,9 +66,7 @@ bool MqttProtocol::StartMqttClient() {
|
|||||||
cJSON_Delete(root);
|
cJSON_Delete(root);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (on_incoming_json_ != nullptr) {
|
|
||||||
on_incoming_json_(root);
|
|
||||||
}
|
|
||||||
if (strcmp(type->valuestring, "hello") == 0) {
|
if (strcmp(type->valuestring, "hello") == 0) {
|
||||||
ParseServerHello(root);
|
ParseServerHello(root);
|
||||||
} else if (strcmp(type->valuestring, "goodbye") == 0) {
|
} else if (strcmp(type->valuestring, "goodbye") == 0) {
|
||||||
@ -70,6 +76,8 @@ bool MqttProtocol::StartMqttClient() {
|
|||||||
on_audio_channel_closed_();
|
on_audio_channel_closed_();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (on_incoming_json_ != nullptr) {
|
||||||
|
on_incoming_json_(root);
|
||||||
}
|
}
|
||||||
cJSON_Delete(root);
|
cJSON_Delete(root);
|
||||||
});
|
});
|
||||||
@ -89,13 +97,17 @@ bool MqttProtocol::StartMqttClient() {
|
|||||||
|
|
||||||
void MqttProtocol::SendText(const std::string& text) {
|
void MqttProtocol::SendText(const std::string& text) {
|
||||||
if (publish_topic_.empty()) {
|
if (publish_topic_.empty()) {
|
||||||
ESP_LOGE(TAG, "Publish topic is not specified");
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
mqtt_->Publish(publish_topic_, text);
|
mqtt_->Publish(publish_topic_, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MqttProtocol::SendAudio(const std::string& data) {
|
void MqttProtocol::SendAudio(const std::string& data) {
|
||||||
|
std::lock_guard<std::mutex> lock(channel_mutex_);
|
||||||
|
if (udp_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::string nonce(aes_nonce_);
|
std::string nonce(aes_nonce_);
|
||||||
*(uint16_t*)&nonce[2] = htons(data.size());
|
*(uint16_t*)&nonce[2] = htons(data.size());
|
||||||
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
|
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
|
||||||
@ -111,14 +123,26 @@ void MqttProtocol::SendAudio(const std::string& data) {
|
|||||||
ESP_LOGE(TAG, "Failed to encrypt audio data");
|
ESP_LOGE(TAG, "Failed to encrypt audio data");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(channel_mutex_);
|
|
||||||
if (udp_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
udp_->Send(encrypted);
|
udp_->Send(encrypted);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MqttProtocol::SendState(const std::string& state) {
|
||||||
|
std::string message = "{";
|
||||||
|
message += "\"session_id\":\"" + session_id_ + "\",";
|
||||||
|
message += "\"type\":\"state\",";
|
||||||
|
message += "\"state\":\"" + state + "\"";
|
||||||
|
message += "}";
|
||||||
|
SendText(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MqttProtocol::SendAbort() {
|
||||||
|
std::string message = "{";
|
||||||
|
message += "\"session_id\":\"" + session_id_ + "\",";
|
||||||
|
message += "\"type\":\"abort\"";
|
||||||
|
message += "}";
|
||||||
|
SendText(message);
|
||||||
|
}
|
||||||
|
|
||||||
void MqttProtocol::CloseAudioChannel() {
|
void MqttProtocol::CloseAudioChannel() {
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(channel_mutex_);
|
std::lock_guard<std::mutex> lock(channel_mutex_);
|
||||||
@ -129,6 +153,7 @@ void MqttProtocol::CloseAudioChannel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string message = "{";
|
std::string message = "{";
|
||||||
|
message += "\"session_id\":\"" + session_id_ + "\",";
|
||||||
message += "\"type\":\"goodbye\"";
|
message += "\"type\":\"goodbye\"";
|
||||||
message += "}";
|
message += "}";
|
||||||
SendText(message);
|
SendText(message);
|
||||||
@ -139,8 +164,8 @@ void MqttProtocol::CloseAudioChannel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool MqttProtocol::OpenAudioChannel() {
|
bool MqttProtocol::OpenAudioChannel() {
|
||||||
if (!mqtt_->IsConnected()) {
|
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
|
||||||
ESP_LOGE(TAG, "MQTT is not connected, try to connect now");
|
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
|
||||||
if (!StartMqttClient()) {
|
if (!StartMqttClient()) {
|
||||||
ESP_LOGE(TAG, "Failed to connect to MQTT");
|
ESP_LOGE(TAG, "Failed to connect to MQTT");
|
||||||
return false;
|
return false;
|
||||||
@ -155,7 +180,7 @@ bool MqttProtocol::OpenAudioChannel() {
|
|||||||
message += "\"version\": 3,";
|
message += "\"version\": 3,";
|
||||||
message += "\"transport\":\"udp\",";
|
message += "\"transport\":\"udp\",";
|
||||||
message += "\"audio_params\":{";
|
message += "\"audio_params\":{";
|
||||||
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":60";
|
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
|
||||||
message += "}}";
|
message += "}}";
|
||||||
SendText(message);
|
SendText(message);
|
||||||
|
|
||||||
@ -185,6 +210,9 @@ bool MqttProtocol::OpenAudioChannel() {
|
|||||||
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
|
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (sequence != remote_sequence_ + 1) {
|
||||||
|
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
|
||||||
|
}
|
||||||
|
|
||||||
std::string decrypted;
|
std::string decrypted;
|
||||||
size_t decrypted_size = data.size() - aes_nonce_.size();
|
size_t decrypted_size = data.size() - aes_nonce_.size();
|
||||||
@ -240,6 +268,15 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
|
|||||||
session_id_ = session_id->valuestring;
|
session_id_ = session_id->valuestring;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get sample rate from hello message
|
||||||
|
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
|
||||||
|
if (audio_params != NULL) {
|
||||||
|
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
|
||||||
|
if (sample_rate != NULL) {
|
||||||
|
server_sample_rate_ = sample_rate->valueint;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto udp = cJSON_GetObjectItem(root, "udp");
|
auto udp = cJSON_GetObjectItem(root, "udp");
|
||||||
if (udp == nullptr) {
|
if (udp == nullptr) {
|
||||||
ESP_LOGE(TAG, "UDP is not specified");
|
ESP_LOGE(TAG, "UDP is not specified");
|
||||||
@ -260,6 +297,10 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
|
|||||||
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
|
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int MqttProtocol::GetServerSampleRate() const {
|
||||||
|
return server_sample_rate_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static const char hex_chars[] = "0123456789ABCDEF";
|
static const char hex_chars[] = "0123456789ABCDEF";
|
||||||
// 辅助函数,将单个十六进制字符转换为对应的数值
|
// 辅助函数,将单个十六进制字符转换为对应的数值
|
||||||
@ -279,3 +320,7 @@ std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
|
|||||||
}
|
}
|
||||||
return decoded;
|
return decoded;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool MqttProtocol::IsAudioChannelOpened() const {
|
||||||
|
return udp_ != nullptr;
|
||||||
|
}
|
||||||
|
@ -21,17 +21,21 @@
|
|||||||
|
|
||||||
class MqttProtocol : public Protocol {
|
class MqttProtocol : public Protocol {
|
||||||
public:
|
public:
|
||||||
MqttProtocol(std::map<std::string, std::string>& config);
|
MqttProtocol();
|
||||||
~MqttProtocol();
|
~MqttProtocol();
|
||||||
|
|
||||||
void OnIncomingAudio(std::function<void(const std::string& data)> callback);
|
void OnIncomingAudio(std::function<void(const std::string& data)> callback);
|
||||||
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
|
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
|
||||||
void SendAudio(const std::string& data);
|
void SendAudio(const std::string& data);
|
||||||
void SendText(const std::string& text);
|
void SendText(const std::string& text);
|
||||||
|
void SendState(const std::string& state);
|
||||||
|
void SendAbort();
|
||||||
bool OpenAudioChannel();
|
bool OpenAudioChannel();
|
||||||
void CloseAudioChannel();
|
void CloseAudioChannel();
|
||||||
void OnAudioChannelOpened(std::function<void()> callback);
|
void OnAudioChannelOpened(std::function<void()> callback);
|
||||||
void OnAudioChannelClosed(std::function<void()> callback);
|
void OnAudioChannelClosed(std::function<void()> callback);
|
||||||
|
bool IsAudioChannelOpened() const;
|
||||||
|
int GetServerSampleRate() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EventGroupHandle_t event_group_handle_;
|
EventGroupHandle_t event_group_handle_;
|
||||||
@ -58,6 +62,7 @@ private:
|
|||||||
uint32_t local_sequence_;
|
uint32_t local_sequence_;
|
||||||
uint32_t remote_sequence_;
|
uint32_t remote_sequence_;
|
||||||
std::string session_id_;
|
std::string session_id_;
|
||||||
|
int server_sample_rate_ = 16000;
|
||||||
|
|
||||||
bool StartMqttClient();
|
bool StartMqttClient();
|
||||||
void ParseServerHello(const cJSON* root);
|
void ParseServerHello(const cJSON* root);
|
||||||
|
63
main/settings.cc
Normal file
63
main/settings.cc
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#include "settings.h"
|
||||||
|
|
||||||
|
#include <esp_log.h>
|
||||||
|
#include <nvs_flash.h>
|
||||||
|
|
||||||
|
#define TAG "Settings"
|
||||||
|
|
||||||
|
Settings::Settings(const std::string& ns, bool read_write) : ns_(ns), read_write_(read_write) {
|
||||||
|
nvs_open(ns.c_str(), read_write_ ? NVS_READWRITE : NVS_READONLY, &nvs_handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Settings::~Settings() {
|
||||||
|
if (nvs_handle_ != 0) {
|
||||||
|
if (read_write_) {
|
||||||
|
ESP_ERROR_CHECK(nvs_commit(nvs_handle_));
|
||||||
|
}
|
||||||
|
nvs_close(nvs_handle_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Settings::GetString(const std::string& key, const std::string& default_value) {
|
||||||
|
if (nvs_handle_ == 0) {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t length = 0;
|
||||||
|
if (nvs_get_str(nvs_handle_, key.c_str(), nullptr, &length) != ESP_OK) {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string value;
|
||||||
|
value.resize(length);
|
||||||
|
ESP_ERROR_CHECK(nvs_get_str(nvs_handle_, key.c_str(), value.data(), &length));
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Settings::SetString(const std::string& key, const std::string& value) {
|
||||||
|
if (read_write_) {
|
||||||
|
ESP_ERROR_CHECK(nvs_set_str(nvs_handle_, key.c_str(), value.c_str()));
|
||||||
|
} else {
|
||||||
|
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t Settings::GetInt(const std::string& key, int32_t default_value) {
|
||||||
|
if (nvs_handle_ == 0) {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t value;
|
||||||
|
if (nvs_get_i32(nvs_handle_, key.c_str(), &value) != ESP_OK) {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Settings::SetInt(const std::string& key, int32_t value) {
|
||||||
|
if (read_write_) {
|
||||||
|
ESP_ERROR_CHECK(nvs_set_i32(nvs_handle_, key.c_str(), value));
|
||||||
|
} else {
|
||||||
|
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
|
||||||
|
}
|
||||||
|
}
|
23
main/settings.h
Normal file
23
main/settings.h
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#ifndef SETTINGS_H
|
||||||
|
#define SETTINGS_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <nvs_flash.h>
|
||||||
|
|
||||||
|
class Settings {
|
||||||
|
public:
|
||||||
|
Settings(const std::string& ns, bool read_write = false);
|
||||||
|
~Settings();
|
||||||
|
|
||||||
|
std::string GetString(const std::string& key, const std::string& default_value = "");
|
||||||
|
void SetString(const std::string& key, const std::string& value);
|
||||||
|
int32_t GetInt(const std::string& key, int32_t default_value = 0);
|
||||||
|
void SetInt(const std::string& key, int32_t value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string ns_;
|
||||||
|
nvs_handle_t nvs_handle_ = 0;
|
||||||
|
bool read_write_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -170,6 +170,7 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void WakeWordDetect::EncodeWakeWordData() {
|
void WakeWordDetect::EncodeWakeWordData() {
|
||||||
|
xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT);
|
||||||
wake_word_opus_.clear();
|
wake_word_opus_.clear();
|
||||||
if (wake_word_encode_task_stack_ == nullptr) {
|
if (wake_word_encode_task_stack_ == nullptr) {
|
||||||
wake_word_encode_task_stack_ = (StackType_t*)malloc(4096 * 8);
|
wake_word_encode_task_stack_ = (StackType_t*)malloc(4096 * 8);
|
||||||
@ -192,7 +193,7 @@ void WakeWordDetect::EncodeWakeWordData() {
|
|||||||
this_->wake_word_pcm_.clear();
|
this_->wake_word_pcm_.clear();
|
||||||
|
|
||||||
auto end_time = esp_timer_get_time();
|
auto end_time = esp_timer_get_time();
|
||||||
ESP_LOGI(TAG, "Encode wake word opus: %zu bytes in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
|
ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
|
||||||
xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT);
|
xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT);
|
||||||
this_->wake_word_cv_.notify_one();
|
this_->wake_word_cv_.notify_one();
|
||||||
delete encoder;
|
delete encoder;
|
||||||
|
@ -5,7 +5,6 @@ CONFIG_BOOTLOADER_APP_ROLLBACK_ENABLE=y
|
|||||||
|
|
||||||
CONFIG_HTTPD_MAX_REQ_HDR_LEN=2048
|
CONFIG_HTTPD_MAX_REQ_HDR_LEN=2048
|
||||||
CONFIG_HTTPD_MAX_URI_LEN=2048
|
CONFIG_HTTPD_MAX_URI_LEN=2048
|
||||||
CONFIG_ESP_MAIN_TASK_STACK_SIZE=8192
|
|
||||||
|
|
||||||
CONFIG_PARTITION_TABLE_CUSTOM=y
|
CONFIG_PARTITION_TABLE_CUSTOM=y
|
||||||
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"
|
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"
|
||||||
|
Reference in New Issue
Block a user