add settings

This commit is contained in:
Terrence
2024-11-15 04:44:53 +08:00
parent ec918748f1
commit 58de3852c5
15 changed files with 237 additions and 89 deletions

View File

@ -4,7 +4,7 @@
# CMakeLists in this exact order for cmake to work correctly # CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16) cmake_minimum_required(VERSION 3.16)
set(PROJECT_VER "0.8.0") set(PROJECT_VER "0.8.1")
include($ENV{IDF_PATH}/tools/cmake/project.cmake) include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(xiaozhi) project(xiaozhi)

View File

@ -16,6 +16,7 @@ set(SOURCES "audio_codec.cc"
"button.cc" "button.cc"
"led.cc" "led.cc"
"ota.cc" "ota.cc"
"settings.cc"
"main.cc" "main.cc"
) )
set(INCLUDE_DIRS ".") set(INCLUDE_DIRS ".")

View File

@ -37,9 +37,6 @@ Application::~Application() {
if (audio_encode_task_stack_ != nullptr) { if (audio_encode_task_stack_ != nullptr) {
heap_caps_free(audio_encode_task_stack_); heap_caps_free(audio_encode_task_stack_);
} }
if (main_loop_task_stack_ != nullptr) {
heap_caps_free(main_loop_task_stack_);
}
vEventGroupDelete(event_group_); vEventGroupDelete(event_group_);
} }
@ -47,20 +44,34 @@ Application::~Application() {
void Application::CheckNewVersion() { void Application::CheckNewVersion() {
// Check if there is a new firmware version available // Check if there is a new firmware version available
ota_.SetPostData(Board::GetInstance().GetJson()); ota_.SetPostData(Board::GetInstance().GetJson());
ota_.CheckVersion();
if (ota_.HasNewVersion()) { while (true) {
SetChatState(kChatStateUpgrading); if (ota_.CheckVersion()) {
ota_.StartUpgrade([](int progress, size_t speed) { if (ota_.HasNewVersion()) {
char buffer[64]; // Wait for the chat state to be idle
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024); while (chat_state_ != kChatStateIdle) {
auto display = Board::GetInstance().GetDisplay(); vTaskDelay(100);
display->SetText(buffer); }
});
// If upgrade success, the device will reboot and never reach here SetChatState(kChatStateUpgrading);
ESP_LOGI(TAG, "Firmware upgrade failed..."); ota_.StartUpgrade([](int progress, size_t speed) {
SetChatState(kChatStateIdle); char buffer[64];
} else { snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
ota_.MarkCurrentVersionValid(); auto display = Board::GetInstance().GetDisplay();
display->SetText(buffer);
});
// If upgrade success, the device will reboot and never reach here
ESP_LOGI(TAG, "Firmware upgrade failed...");
SetChatState(kChatStateIdle);
} else {
ota_.MarkCurrentVersionValid();
}
return;
}
// Check again in 60 seconds
vTaskDelay(pdMS_TO_TICKS(60000));
} }
} }
@ -191,24 +202,18 @@ void Application::Start() {
/* Wait for the network to be ready */ /* Wait for the network to be ready */
board.StartNetwork(); board.StartNetwork();
const size_t main_loop_stack_size = 4096 * 8; xTaskCreate([](void* arg) {
main_loop_task_stack_ = (StackType_t*)heap_caps_malloc(main_loop_stack_size, MALLOC_CAP_SPIRAM);
xTaskCreateStatic([](void* arg) {
Application* app = (Application*)arg; Application* app = (Application*)arg;
app->MainLoop(); app->MainLoop();
vTaskDelete(NULL); vTaskDelete(NULL);
}, "main_loop", main_loop_stack_size, this, 1, main_loop_task_stack_, &main_loop_task_buffer_); }, "main_loop", 4096 * 2, this, 1, nullptr);
// Check for new firmware version or get the MQTT broker address // Check for new firmware version or get the MQTT broker address
while (true) { xTaskCreate([](void* arg) {
CheckNewVersion(); Application* app = (Application*)arg;
app->CheckNewVersion();
if (ota_.HasMqttConfig()) { vTaskDelete(NULL);
break; }, "check_new_version", 4096 * 2, this, 1, nullptr);
}
Alert("Error", "Missing MQTT config");
vTaskDelay(pdMS_TO_TICKS(10000));
}
#ifdef CONFIG_USE_AFE_SR #ifdef CONFIG_USE_AFE_SR
audio_processor_.Initialize(codec->input_channels(), codec->input_reference()); audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
@ -264,12 +269,19 @@ void Application::Start() {
// Initialize the protocol // Initialize the protocol
display->SetText("Starting\nProtocol..."); display->SetText("Starting\nProtocol...");
protocol_ = new MqttProtocol(ota_.GetMqttConfig()); protocol_ = new MqttProtocol();
protocol_->OnIncomingAudio([this](const std::string& data) { protocol_->OnIncomingAudio([this](const std::string& data) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.emplace_back(std::move(data)); audio_decode_queue_.emplace_back(std::move(data));
cv_.notify_all(); cv_.notify_all();
}); });
protocol_->OnAudioChannelOpened([this, codec]() {
if (protocol_->GetServerSampleRate() != codec->output_sample_rate()) {
ESP_LOGW(TAG, "服务器的音频采样率 %d 与设备输出的采样率 %d 不一致,重采样后可能会失真",
protocol_->GetServerSampleRate(), codec->output_sample_rate());
}
SetDecodeSampleRate(protocol_->GetServerSampleRate());
});
protocol_->OnAudioChannelClosed([this]() { protocol_->OnAudioChannelClosed([this]() {
Schedule([this]() { Schedule([this]() {
SetChatState(kChatStateIdle); SetChatState(kChatStateIdle);
@ -289,7 +301,9 @@ void Application::Start() {
Schedule([this]() { Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec(); auto codec = Board::GetInstance().GetAudioCodec();
codec->WaitForOutputDone(); codec->WaitForOutputDone();
SetChatState(kChatStateListening); if (chat_state_ == kChatStateSpeaking) {
SetChatState(kChatStateListening);
}
}); });
} else if (strcmp(state->valuestring, "sentence_start") == 0) { } else if (strcmp(state->valuestring, "sentence_start") == 0) {
auto text = cJSON_GetObjectItem(root, "text"); auto text = cJSON_GetObjectItem(root, "text");
@ -307,15 +321,6 @@ void Application::Start() {
if (emotion != NULL) { if (emotion != NULL) {
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring); ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
} }
} else if (strcmp(type->valuestring, "hello") == 0) {
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
SetDecodeSampleRate(sample_rate->valueint);
}
}
} }
}); });
@ -351,8 +356,7 @@ void Application::MainLoop() {
void Application::AbortSpeaking() { void Application::AbortSpeaking() {
ESP_LOGI(TAG, "Abort speaking"); ESP_LOGI(TAG, "Abort speaking");
std::string json = "{\"type\":\"abort\"}"; protocol_->SendAbort();
protocol_->SendText(json);
skip_to_end_ = true; skip_to_end_ = true;
auto codec = Board::GetInstance().GetAudioCodec(); auto codec = Board::GetInstance().GetAudioCodec();
@ -420,10 +424,7 @@ void Application::SetChatState(ChatState state) {
break; break;
} }
std::string json = "{\"type\":\"state\",\"state\":\""; protocol_->SendState(state_str[chat_state_]);
json += state_str[chat_state_];
json += "\"}";
protocol_->SendText(json);
} }
void Application::AudioEncodeTask() { void Application::AudioEncodeTask() {
@ -455,7 +456,7 @@ void Application::AudioEncodeTask() {
continue; continue;
} }
int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000; int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000;
std::vector<int16_t> pcm(frame_size); std::vector<int16_t> pcm(frame_size);
int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0); int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);

View File

@ -39,6 +39,8 @@ enum ChatState {
kChatStateUpgrading kChatStateUpgrading
}; };
#define OPUS_FRAME_DURATION_MS 60
class Application { class Application {
public: public:
static Application& GetInstance() { static Application& GetInstance() {
@ -84,16 +86,11 @@ private:
OpusEncoder opus_encoder_; OpusEncoder opus_encoder_;
OpusDecoder* opus_decoder_ = nullptr; OpusDecoder* opus_decoder_ = nullptr;
int opus_duration_ms_ = 60;
int opus_decode_sample_rate_ = -1; int opus_decode_sample_rate_ = -1;
OpusResampler input_resampler_; OpusResampler input_resampler_;
OpusResampler reference_resampler_; OpusResampler reference_resampler_;
OpusResampler output_resampler_; OpusResampler output_resampler_;
TaskHandle_t main_loop_task_ = nullptr;
StaticTask_t main_loop_task_buffer_;
StackType_t* main_loop_task_stack_ = nullptr;
void MainLoop(); void MainLoop();
void SetDecodeSampleRate(int sample_rate); void SetDecodeSampleRate(int sample_rate);
void CheckNewVersion(); void CheckNewVersion();

View File

@ -1,5 +1,6 @@
#include "audio_codec.h" #include "audio_codec.h"
#include "board.h" #include "board.h"
#include "settings.h"
#include <esp_log.h> #include <esp_log.h>
#include <cstring> #include <cstring>
@ -40,6 +41,9 @@ IRAM_ATTR bool AudioCodec::on_sent(i2s_chan_handle_t handle, i2s_event_data_t *e
} }
void AudioCodec::Start() { void AudioCodec::Start() {
Settings settings("audio", false);
output_volume_ = settings.GetInt("output_volume", output_volume_);
// 注册音频输出回调 // 注册音频输出回调
i2s_event_callbacks_t callbacks = {}; i2s_event_callbacks_t callbacks = {};
callbacks.on_sent = on_sent; callbacks.on_sent = on_sent;
@ -124,6 +128,9 @@ void AudioCodec::ClearOutputQueue() {
void AudioCodec::SetOutputVolume(int volume) { void AudioCodec::SetOutputVolume(int volume) {
output_volume_ = volume; output_volume_ = volume;
ESP_LOGI(TAG, "Set output volume to %d", output_volume_); ESP_LOGI(TAG, "Set output volume to %d", output_volume_);
Settings settings("audio", true);
settings.SetInt("output_volume", output_volume_);
} }
void AudioCodec::EnableInput(bool enable) { void AudioCodec::EnableInput(bool enable) {

View File

@ -18,6 +18,7 @@ extern "C" void app_main(void)
// Initialize NVS flash for WiFi configuration // Initialize NVS flash for WiFi configuration
esp_err_t ret = nvs_flash_init(); esp_err_t ret = nvs_flash_init();
if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) { if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
ESP_LOGW(TAG, "Erasing NVS flash to fix corruption");
ESP_ERROR_CHECK(nvs_flash_erase()); ESP_ERROR_CHECK(nvs_flash_erase());
ret = nvs_flash_init(); ret = nvs_flash_init();
} }

View File

@ -1,6 +1,7 @@
#include "ota.h" #include "ota.h"
#include "system_info.h" #include "system_info.h"
#include "board.h" #include "board.h"
#include "settings.h"
#include <cJSON.h> #include <cJSON.h>
#include <esp_log.h> #include <esp_log.h>
@ -34,13 +35,13 @@ void Ota::SetPostData(const std::string& post_data) {
post_data_ = post_data; post_data_ = post_data;
} }
void Ota::CheckVersion() { bool Ota::CheckVersion() {
std::string current_version = esp_app_get_description()->version; std::string current_version = esp_app_get_description()->version;
ESP_LOGI(TAG, "Current version: %s", current_version.c_str()); ESP_LOGI(TAG, "Current version: %s", current_version.c_str());
if (check_version_url_.length() < 10) { if (check_version_url_.length() < 10) {
ESP_LOGE(TAG, "Check version URL is not properly set"); ESP_LOGE(TAG, "Check version URL is not properly set");
return; return false;
} }
auto http = Board::GetInstance().CreateHttp(); auto http = Board::GetInstance().CreateHttp();
@ -67,16 +68,18 @@ void Ota::CheckVersion() {
cJSON *root = cJSON_Parse(response.c_str()); cJSON *root = cJSON_Parse(response.c_str());
if (root == NULL) { if (root == NULL) {
ESP_LOGE(TAG, "Failed to parse JSON response"); ESP_LOGE(TAG, "Failed to parse JSON response");
return; return false;
} }
cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt"); cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
if (mqtt != NULL) { if (mqtt != NULL) {
Settings settings("mqtt", true);
cJSON *item = NULL; cJSON *item = NULL;
cJSON_ArrayForEach(item, mqtt) { cJSON_ArrayForEach(item, mqtt) {
if (item->type == cJSON_String) { if (item->type == cJSON_String) {
mqtt_config_[item->string] = item->valuestring; if (settings.GetString(item->string) != item->valuestring) {
ESP_LOGI(TAG, "MQTT config: %s = %s", item->string, item->valuestring); settings.SetString(item->string, item->valuestring);
}
} }
} }
has_mqtt_config_ = true; has_mqtt_config_ = true;
@ -86,19 +89,19 @@ void Ota::CheckVersion() {
if (firmware == NULL) { if (firmware == NULL) {
ESP_LOGE(TAG, "Failed to get firmware object"); ESP_LOGE(TAG, "Failed to get firmware object");
cJSON_Delete(root); cJSON_Delete(root);
return; return false;
} }
cJSON *version = cJSON_GetObjectItem(firmware, "version"); cJSON *version = cJSON_GetObjectItem(firmware, "version");
if (version == NULL) { if (version == NULL) {
ESP_LOGE(TAG, "Failed to get version object"); ESP_LOGE(TAG, "Failed to get version object");
cJSON_Delete(root); cJSON_Delete(root);
return; return false;
} }
cJSON *url = cJSON_GetObjectItem(firmware, "url"); cJSON *url = cJSON_GetObjectItem(firmware, "url");
if (url == NULL) { if (url == NULL) {
ESP_LOGE(TAG, "Failed to get url object"); ESP_LOGE(TAG, "Failed to get url object");
cJSON_Delete(root); cJSON_Delete(root);
return; return false;
} }
firmware_version_ = version->valuestring; firmware_version_ = version->valuestring;
@ -112,6 +115,7 @@ void Ota::CheckVersion() {
} else { } else {
ESP_LOGI(TAG, "Current is the latest version"); ESP_LOGI(TAG, "Current is the latest version");
} }
return true;
} }
void Ota::MarkCurrentVersionValid() { void Ota::MarkCurrentVersionValid() {

View File

@ -13,14 +13,12 @@ public:
void SetCheckVersionUrl(std::string check_version_url); void SetCheckVersionUrl(std::string check_version_url);
void SetHeader(const std::string& key, const std::string& value); void SetHeader(const std::string& key, const std::string& value);
void SetPostData(const std::string& post_data); void SetPostData(const std::string& post_data);
void CheckVersion(); bool CheckVersion();
bool HasNewVersion() { return has_new_version_; } bool HasNewVersion() { return has_new_version_; }
bool HasMqttConfig() { return has_mqtt_config_; } bool HasMqttConfig() { return has_mqtt_config_; }
void StartUpgrade(std::function<void(int progress, size_t speed)> callback); void StartUpgrade(std::function<void(int progress, size_t speed)> callback);
void MarkCurrentVersionValid(); void MarkCurrentVersionValid();
std::map<std::string, std::string>& GetMqttConfig() { return mqtt_config_; }
private: private:
std::string check_version_url_; std::string check_version_url_;
bool has_new_version_ = false; bool has_new_version_ = false;
@ -29,7 +27,6 @@ private:
std::string firmware_url_; std::string firmware_url_;
std::string post_data_; std::string post_data_;
std::map<std::string, std::string> headers_; std::map<std::string, std::string> headers_;
std::map<std::string, std::string> mqtt_config_;
void Upgrade(const std::string& firmware_url); void Upgrade(const std::string& firmware_url);
std::function<void(int progress, size_t speed)> upgrade_callback_; std::function<void(int progress, size_t speed)> upgrade_callback_;

View File

@ -13,10 +13,14 @@ public:
virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0; virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0;
virtual void SendAudio(const std::string& data) = 0; virtual void SendAudio(const std::string& data) = 0;
virtual void SendText(const std::string& text) = 0; virtual void SendText(const std::string& text) = 0;
virtual void SendState(const std::string& state) = 0;
virtual void SendAbort() = 0;
virtual bool OpenAudioChannel() = 0; virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0; virtual void CloseAudioChannel() = 0;
virtual void OnAudioChannelOpened(std::function<void()> callback) = 0; virtual void OnAudioChannelOpened(std::function<void()> callback) = 0;
virtual void OnAudioChannelClosed(std::function<void()> callback) = 0; virtual void OnAudioChannelClosed(std::function<void()> callback) = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual int GetServerSampleRate() const = 0;
}; };
#endif // PROTOCOL_H #endif // PROTOCOL_H

View File

@ -1,5 +1,7 @@
#include "mqtt_protocol.h" #include "mqtt_protocol.h"
#include "board.h" #include "board.h"
#include "application.h"
#include "settings.h"
#include <esp_log.h> #include <esp_log.h>
#include <ml307_mqtt.h> #include <ml307_mqtt.h>
@ -9,16 +11,9 @@
#define TAG "MQTT" #define TAG "MQTT"
MqttProtocol::MqttProtocol(std::map<std::string, std::string>& config) { MqttProtocol::MqttProtocol() {
event_group_handle_ = xEventGroupCreate(); event_group_handle_ = xEventGroupCreate();
endpoint_ = config["endpoint"];
client_id_ = config["client_id"];
username_ = config["username"];
password_ = config["password"];
subscribe_topic_ = config["subscribe_topic"];
publish_topic_ = config["publish_topic"];
StartMqttClient(); StartMqttClient();
} }
@ -39,6 +34,19 @@ bool MqttProtocol::StartMqttClient() {
delete mqtt_; delete mqtt_;
} }
Settings settings("mqtt", false);
endpoint_ = settings.GetString("endpoint");
client_id_ = settings.GetString("client_id");
username_ = settings.GetString("username");
password_ = settings.GetString("password");
subscribe_topic_ = settings.GetString("subscribe_topic");
publish_topic_ = settings.GetString("publish_topic");
if (endpoint_.empty()) {
ESP_LOGE(TAG, "MQTT endpoint is not specified");
return false;
}
mqtt_ = Board::GetInstance().CreateMqtt(); mqtt_ = Board::GetInstance().CreateMqtt();
mqtt_->SetKeepAlive(90); mqtt_->SetKeepAlive(90);
@ -58,9 +66,7 @@ bool MqttProtocol::StartMqttClient() {
cJSON_Delete(root); cJSON_Delete(root);
return; return;
} }
if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
if (strcmp(type->valuestring, "hello") == 0) { if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root); ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) { } else if (strcmp(type->valuestring, "goodbye") == 0) {
@ -70,6 +76,8 @@ bool MqttProtocol::StartMqttClient() {
on_audio_channel_closed_(); on_audio_channel_closed_();
} }
} }
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
} }
cJSON_Delete(root); cJSON_Delete(root);
}); });
@ -89,13 +97,17 @@ bool MqttProtocol::StartMqttClient() {
void MqttProtocol::SendText(const std::string& text) { void MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) { if (publish_topic_.empty()) {
ESP_LOGE(TAG, "Publish topic is not specified");
return; return;
} }
mqtt_->Publish(publish_topic_, text); mqtt_->Publish(publish_topic_, text);
} }
void MqttProtocol::SendAudio(const std::string& data) { void MqttProtocol::SendAudio(const std::string& data) {
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
std::string nonce(aes_nonce_); std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(data.size()); *(uint16_t*)&nonce[2] = htons(data.size());
*(uint32_t*)&nonce[12] = htonl(++local_sequence_); *(uint32_t*)&nonce[12] = htonl(++local_sequence_);
@ -111,14 +123,26 @@ void MqttProtocol::SendAudio(const std::string& data) {
ESP_LOGE(TAG, "Failed to encrypt audio data"); ESP_LOGE(TAG, "Failed to encrypt audio data");
return; return;
} }
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
udp_->Send(encrypted); udp_->Send(encrypted);
} }
void MqttProtocol::SendState(const std::string& state) {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"state\",";
message += "\"state\":\"" + state + "\"";
message += "}";
SendText(message);
}
void MqttProtocol::SendAbort() {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"abort\"";
message += "}";
SendText(message);
}
void MqttProtocol::CloseAudioChannel() { void MqttProtocol::CloseAudioChannel() {
{ {
std::lock_guard<std::mutex> lock(channel_mutex_); std::lock_guard<std::mutex> lock(channel_mutex_);
@ -129,6 +153,7 @@ void MqttProtocol::CloseAudioChannel() {
} }
std::string message = "{"; std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"goodbye\""; message += "\"type\":\"goodbye\"";
message += "}"; message += "}";
SendText(message); SendText(message);
@ -139,8 +164,8 @@ void MqttProtocol::CloseAudioChannel() {
} }
bool MqttProtocol::OpenAudioChannel() { bool MqttProtocol::OpenAudioChannel() {
if (!mqtt_->IsConnected()) { if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGE(TAG, "MQTT is not connected, try to connect now"); ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
if (!StartMqttClient()) { if (!StartMqttClient()) {
ESP_LOGE(TAG, "Failed to connect to MQTT"); ESP_LOGE(TAG, "Failed to connect to MQTT");
return false; return false;
@ -155,7 +180,7 @@ bool MqttProtocol::OpenAudioChannel() {
message += "\"version\": 3,"; message += "\"version\": 3,";
message += "\"transport\":\"udp\","; message += "\"transport\":\"udp\",";
message += "\"audio_params\":{"; message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":60"; message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}"; message += "}}";
SendText(message); SendText(message);
@ -185,6 +210,9 @@ bool MqttProtocol::OpenAudioChannel() {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_); ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return; return;
} }
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
std::string decrypted; std::string decrypted;
size_t decrypted_size = data.size() - aes_nonce_.size(); size_t decrypted_size = data.size() - aes_nonce_.size();
@ -240,6 +268,15 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
session_id_ = session_id->valuestring; session_id_ = session_id->valuestring;
} }
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
server_sample_rate_ = sample_rate->valueint;
}
}
auto udp = cJSON_GetObjectItem(root, "udp"); auto udp = cJSON_GetObjectItem(root, "udp");
if (udp == nullptr) { if (udp == nullptr) {
ESP_LOGE(TAG, "UDP is not specified"); ESP_LOGE(TAG, "UDP is not specified");
@ -260,6 +297,10 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT); xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
} }
int MqttProtocol::GetServerSampleRate() const {
return server_sample_rate_;
}
static const char hex_chars[] = "0123456789ABCDEF"; static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值 // 辅助函数,将单个十六进制字符转换为对应的数值
@ -279,3 +320,7 @@ std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
} }
return decoded; return decoded;
} }
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr;
}

View File

@ -21,17 +21,21 @@
class MqttProtocol : public Protocol { class MqttProtocol : public Protocol {
public: public:
MqttProtocol(std::map<std::string, std::string>& config); MqttProtocol();
~MqttProtocol(); ~MqttProtocol();
void OnIncomingAudio(std::function<void(const std::string& data)> callback); void OnIncomingAudio(std::function<void(const std::string& data)> callback);
void OnIncomingJson(std::function<void(const cJSON* root)> callback); void OnIncomingJson(std::function<void(const cJSON* root)> callback);
void SendAudio(const std::string& data); void SendAudio(const std::string& data);
void SendText(const std::string& text); void SendText(const std::string& text);
void SendState(const std::string& state);
void SendAbort();
bool OpenAudioChannel(); bool OpenAudioChannel();
void CloseAudioChannel(); void CloseAudioChannel();
void OnAudioChannelOpened(std::function<void()> callback); void OnAudioChannelOpened(std::function<void()> callback);
void OnAudioChannelClosed(std::function<void()> callback); void OnAudioChannelClosed(std::function<void()> callback);
bool IsAudioChannelOpened() const;
int GetServerSampleRate() const;
private: private:
EventGroupHandle_t event_group_handle_; EventGroupHandle_t event_group_handle_;
@ -58,6 +62,7 @@ private:
uint32_t local_sequence_; uint32_t local_sequence_;
uint32_t remote_sequence_; uint32_t remote_sequence_;
std::string session_id_; std::string session_id_;
int server_sample_rate_ = 16000;
bool StartMqttClient(); bool StartMqttClient();
void ParseServerHello(const cJSON* root); void ParseServerHello(const cJSON* root);

63
main/settings.cc Normal file
View File

@ -0,0 +1,63 @@
#include "settings.h"
#include <esp_log.h>
#include <nvs_flash.h>
#define TAG "Settings"
Settings::Settings(const std::string& ns, bool read_write) : ns_(ns), read_write_(read_write) {
nvs_open(ns.c_str(), read_write_ ? NVS_READWRITE : NVS_READONLY, &nvs_handle_);
}
Settings::~Settings() {
if (nvs_handle_ != 0) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_commit(nvs_handle_));
}
nvs_close(nvs_handle_);
}
}
std::string Settings::GetString(const std::string& key, const std::string& default_value) {
if (nvs_handle_ == 0) {
return default_value;
}
size_t length = 0;
if (nvs_get_str(nvs_handle_, key.c_str(), nullptr, &length) != ESP_OK) {
return default_value;
}
std::string value;
value.resize(length);
ESP_ERROR_CHECK(nvs_get_str(nvs_handle_, key.c_str(), value.data(), &length));
return value;
}
void Settings::SetString(const std::string& key, const std::string& value) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_set_str(nvs_handle_, key.c_str(), value.c_str()));
} else {
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
}
}
int32_t Settings::GetInt(const std::string& key, int32_t default_value) {
if (nvs_handle_ == 0) {
return default_value;
}
int32_t value;
if (nvs_get_i32(nvs_handle_, key.c_str(), &value) != ESP_OK) {
return default_value;
}
return value;
}
void Settings::SetInt(const std::string& key, int32_t value) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_set_i32(nvs_handle_, key.c_str(), value));
} else {
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
}
}

23
main/settings.h Normal file
View File

@ -0,0 +1,23 @@
#ifndef SETTINGS_H
#define SETTINGS_H
#include <string>
#include <nvs_flash.h>
class Settings {
public:
Settings(const std::string& ns, bool read_write = false);
~Settings();
std::string GetString(const std::string& key, const std::string& default_value = "");
void SetString(const std::string& key, const std::string& value);
int32_t GetInt(const std::string& key, int32_t default_value = 0);
void SetInt(const std::string& key, int32_t value);
private:
std::string ns_;
nvs_handle_t nvs_handle_ = 0;
bool read_write_ = false;
};
#endif

View File

@ -170,6 +170,7 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
} }
void WakeWordDetect::EncodeWakeWordData() { void WakeWordDetect::EncodeWakeWordData() {
xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT);
wake_word_opus_.clear(); wake_word_opus_.clear();
if (wake_word_encode_task_stack_ == nullptr) { if (wake_word_encode_task_stack_ == nullptr) {
wake_word_encode_task_stack_ = (StackType_t*)malloc(4096 * 8); wake_word_encode_task_stack_ = (StackType_t*)malloc(4096 * 8);
@ -192,7 +193,7 @@ void WakeWordDetect::EncodeWakeWordData() {
this_->wake_word_pcm_.clear(); this_->wake_word_pcm_.clear();
auto end_time = esp_timer_get_time(); auto end_time = esp_timer_get_time();
ESP_LOGI(TAG, "Encode wake word opus: %zu bytes in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000); ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT); xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT);
this_->wake_word_cv_.notify_one(); this_->wake_word_cv_.notify_one();
delete encoder; delete encoder;

View File

@ -5,7 +5,6 @@ CONFIG_BOOTLOADER_APP_ROLLBACK_ENABLE=y
CONFIG_HTTPD_MAX_REQ_HDR_LEN=2048 CONFIG_HTTPD_MAX_REQ_HDR_LEN=2048
CONFIG_HTTPD_MAX_URI_LEN=2048 CONFIG_HTTPD_MAX_URI_LEN=2048
CONFIG_ESP_MAIN_TASK_STACK_SIZE=8192
CONFIG_PARTITION_TABLE_CUSTOM=y CONFIG_PARTITION_TABLE_CUSTOM=y
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv" CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"