add settings

This commit is contained in:
Terrence
2024-11-15 04:44:53 +08:00
parent ec918748f1
commit 58de3852c5
15 changed files with 237 additions and 89 deletions

View File

@ -4,7 +4,7 @@
# CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16)
set(PROJECT_VER "0.8.0")
set(PROJECT_VER "0.8.1")
include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(xiaozhi)

View File

@ -16,6 +16,7 @@ set(SOURCES "audio_codec.cc"
"button.cc"
"led.cc"
"ota.cc"
"settings.cc"
"main.cc"
)
set(INCLUDE_DIRS ".")

View File

@ -37,9 +37,6 @@ Application::~Application() {
if (audio_encode_task_stack_ != nullptr) {
heap_caps_free(audio_encode_task_stack_);
}
if (main_loop_task_stack_ != nullptr) {
heap_caps_free(main_loop_task_stack_);
}
vEventGroupDelete(event_group_);
}
@ -47,20 +44,34 @@ Application::~Application() {
void Application::CheckNewVersion() {
// Check if there is a new firmware version available
ota_.SetPostData(Board::GetInstance().GetJson());
ota_.CheckVersion();
if (ota_.HasNewVersion()) {
SetChatState(kChatStateUpgrading);
ota_.StartUpgrade([](int progress, size_t speed) {
char buffer[64];
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
auto display = Board::GetInstance().GetDisplay();
display->SetText(buffer);
});
// If upgrade success, the device will reboot and never reach here
ESP_LOGI(TAG, "Firmware upgrade failed...");
SetChatState(kChatStateIdle);
} else {
ota_.MarkCurrentVersionValid();
while (true) {
if (ota_.CheckVersion()) {
if (ota_.HasNewVersion()) {
// Wait for the chat state to be idle
while (chat_state_ != kChatStateIdle) {
vTaskDelay(100);
}
SetChatState(kChatStateUpgrading);
ota_.StartUpgrade([](int progress, size_t speed) {
char buffer[64];
snprintf(buffer, sizeof(buffer), "Upgrading...\n %d%% %zuKB/s", progress, speed / 1024);
auto display = Board::GetInstance().GetDisplay();
display->SetText(buffer);
});
// If upgrade success, the device will reboot and never reach here
ESP_LOGI(TAG, "Firmware upgrade failed...");
SetChatState(kChatStateIdle);
} else {
ota_.MarkCurrentVersionValid();
}
return;
}
// Check again in 60 seconds
vTaskDelay(pdMS_TO_TICKS(60000));
}
}
@ -191,24 +202,18 @@ void Application::Start() {
/* Wait for the network to be ready */
board.StartNetwork();
const size_t main_loop_stack_size = 4096 * 8;
main_loop_task_stack_ = (StackType_t*)heap_caps_malloc(main_loop_stack_size, MALLOC_CAP_SPIRAM);
xTaskCreateStatic([](void* arg) {
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->MainLoop();
vTaskDelete(NULL);
}, "main_loop", main_loop_stack_size, this, 1, main_loop_task_stack_, &main_loop_task_buffer_);
}, "main_loop", 4096 * 2, this, 1, nullptr);
// Check for new firmware version or get the MQTT broker address
while (true) {
CheckNewVersion();
if (ota_.HasMqttConfig()) {
break;
}
Alert("Error", "Missing MQTT config");
vTaskDelay(pdMS_TO_TICKS(10000));
}
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->CheckNewVersion();
vTaskDelete(NULL);
}, "check_new_version", 4096 * 2, this, 1, nullptr);
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
@ -264,12 +269,19 @@ void Application::Start() {
// Initialize the protocol
display->SetText("Starting\nProtocol...");
protocol_ = new MqttProtocol(ota_.GetMqttConfig());
protocol_ = new MqttProtocol();
protocol_->OnIncomingAudio([this](const std::string& data) {
std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.emplace_back(std::move(data));
cv_.notify_all();
});
protocol_->OnAudioChannelOpened([this, codec]() {
if (protocol_->GetServerSampleRate() != codec->output_sample_rate()) {
ESP_LOGW(TAG, "服务器的音频采样率 %d 与设备输出的采样率 %d 不一致,重采样后可能会失真",
protocol_->GetServerSampleRate(), codec->output_sample_rate());
}
SetDecodeSampleRate(protocol_->GetServerSampleRate());
});
protocol_->OnAudioChannelClosed([this]() {
Schedule([this]() {
SetChatState(kChatStateIdle);
@ -289,7 +301,9 @@ void Application::Start() {
Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec();
codec->WaitForOutputDone();
SetChatState(kChatStateListening);
if (chat_state_ == kChatStateSpeaking) {
SetChatState(kChatStateListening);
}
});
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
@ -307,15 +321,6 @@ void Application::Start() {
if (emotion != NULL) {
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
}
} else if (strcmp(type->valuestring, "hello") == 0) {
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
SetDecodeSampleRate(sample_rate->valueint);
}
}
}
});
@ -351,8 +356,7 @@ void Application::MainLoop() {
void Application::AbortSpeaking() {
ESP_LOGI(TAG, "Abort speaking");
std::string json = "{\"type\":\"abort\"}";
protocol_->SendText(json);
protocol_->SendAbort();
skip_to_end_ = true;
auto codec = Board::GetInstance().GetAudioCodec();
@ -420,10 +424,7 @@ void Application::SetChatState(ChatState state) {
break;
}
std::string json = "{\"type\":\"state\",\"state\":\"";
json += state_str[chat_state_];
json += "\"}";
protocol_->SendText(json);
protocol_->SendState(state_str[chat_state_]);
}
void Application::AudioEncodeTask() {
@ -455,7 +456,7 @@ void Application::AudioEncodeTask() {
continue;
}
int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000;
int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000;
std::vector<int16_t> pcm(frame_size);
int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);

View File

@ -39,6 +39,8 @@ enum ChatState {
kChatStateUpgrading
};
#define OPUS_FRAME_DURATION_MS 60
class Application {
public:
static Application& GetInstance() {
@ -84,16 +86,11 @@ private:
OpusEncoder opus_encoder_;
OpusDecoder* opus_decoder_ = nullptr;
int opus_duration_ms_ = 60;
int opus_decode_sample_rate_ = -1;
OpusResampler input_resampler_;
OpusResampler reference_resampler_;
OpusResampler output_resampler_;
TaskHandle_t main_loop_task_ = nullptr;
StaticTask_t main_loop_task_buffer_;
StackType_t* main_loop_task_stack_ = nullptr;
void MainLoop();
void SetDecodeSampleRate(int sample_rate);
void CheckNewVersion();

View File

@ -1,5 +1,6 @@
#include "audio_codec.h"
#include "board.h"
#include "settings.h"
#include <esp_log.h>
#include <cstring>
@ -40,6 +41,9 @@ IRAM_ATTR bool AudioCodec::on_sent(i2s_chan_handle_t handle, i2s_event_data_t *e
}
void AudioCodec::Start() {
Settings settings("audio", false);
output_volume_ = settings.GetInt("output_volume", output_volume_);
// 注册音频输出回调
i2s_event_callbacks_t callbacks = {};
callbacks.on_sent = on_sent;
@ -124,6 +128,9 @@ void AudioCodec::ClearOutputQueue() {
void AudioCodec::SetOutputVolume(int volume) {
output_volume_ = volume;
ESP_LOGI(TAG, "Set output volume to %d", output_volume_);
Settings settings("audio", true);
settings.SetInt("output_volume", output_volume_);
}
void AudioCodec::EnableInput(bool enable) {

View File

@ -18,6 +18,7 @@ extern "C" void app_main(void)
// Initialize NVS flash for WiFi configuration
esp_err_t ret = nvs_flash_init();
if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
ESP_LOGW(TAG, "Erasing NVS flash to fix corruption");
ESP_ERROR_CHECK(nvs_flash_erase());
ret = nvs_flash_init();
}

View File

@ -1,6 +1,7 @@
#include "ota.h"
#include "system_info.h"
#include "board.h"
#include "settings.h"
#include <cJSON.h>
#include <esp_log.h>
@ -34,13 +35,13 @@ void Ota::SetPostData(const std::string& post_data) {
post_data_ = post_data;
}
void Ota::CheckVersion() {
bool Ota::CheckVersion() {
std::string current_version = esp_app_get_description()->version;
ESP_LOGI(TAG, "Current version: %s", current_version.c_str());
if (check_version_url_.length() < 10) {
ESP_LOGE(TAG, "Check version URL is not properly set");
return;
return false;
}
auto http = Board::GetInstance().CreateHttp();
@ -67,16 +68,18 @@ void Ota::CheckVersion() {
cJSON *root = cJSON_Parse(response.c_str());
if (root == NULL) {
ESP_LOGE(TAG, "Failed to parse JSON response");
return;
return false;
}
cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
if (mqtt != NULL) {
Settings settings("mqtt", true);
cJSON *item = NULL;
cJSON_ArrayForEach(item, mqtt) {
if (item->type == cJSON_String) {
mqtt_config_[item->string] = item->valuestring;
ESP_LOGI(TAG, "MQTT config: %s = %s", item->string, item->valuestring);
if (settings.GetString(item->string) != item->valuestring) {
settings.SetString(item->string, item->valuestring);
}
}
}
has_mqtt_config_ = true;
@ -86,19 +89,19 @@ void Ota::CheckVersion() {
if (firmware == NULL) {
ESP_LOGE(TAG, "Failed to get firmware object");
cJSON_Delete(root);
return;
return false;
}
cJSON *version = cJSON_GetObjectItem(firmware, "version");
if (version == NULL) {
ESP_LOGE(TAG, "Failed to get version object");
cJSON_Delete(root);
return;
return false;
}
cJSON *url = cJSON_GetObjectItem(firmware, "url");
if (url == NULL) {
ESP_LOGE(TAG, "Failed to get url object");
cJSON_Delete(root);
return;
return false;
}
firmware_version_ = version->valuestring;
@ -112,6 +115,7 @@ void Ota::CheckVersion() {
} else {
ESP_LOGI(TAG, "Current is the latest version");
}
return true;
}
void Ota::MarkCurrentVersionValid() {

View File

@ -13,14 +13,12 @@ public:
void SetCheckVersionUrl(std::string check_version_url);
void SetHeader(const std::string& key, const std::string& value);
void SetPostData(const std::string& post_data);
void CheckVersion();
bool CheckVersion();
bool HasNewVersion() { return has_new_version_; }
bool HasMqttConfig() { return has_mqtt_config_; }
void StartUpgrade(std::function<void(int progress, size_t speed)> callback);
void MarkCurrentVersionValid();
std::map<std::string, std::string>& GetMqttConfig() { return mqtt_config_; }
private:
std::string check_version_url_;
bool has_new_version_ = false;
@ -29,7 +27,6 @@ private:
std::string firmware_url_;
std::string post_data_;
std::map<std::string, std::string> headers_;
std::map<std::string, std::string> mqtt_config_;
void Upgrade(const std::string& firmware_url);
std::function<void(int progress, size_t speed)> upgrade_callback_;

View File

@ -13,10 +13,14 @@ public:
virtual void OnIncomingJson(std::function<void(const cJSON* root)> callback) = 0;
virtual void SendAudio(const std::string& data) = 0;
virtual void SendText(const std::string& text) = 0;
virtual void SendState(const std::string& state) = 0;
virtual void SendAbort() = 0;
virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0;
virtual void OnAudioChannelOpened(std::function<void()> callback) = 0;
virtual void OnAudioChannelClosed(std::function<void()> callback) = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual int GetServerSampleRate() const = 0;
};
#endif // PROTOCOL_H

View File

@ -1,5 +1,7 @@
#include "mqtt_protocol.h"
#include "board.h"
#include "application.h"
#include "settings.h"
#include <esp_log.h>
#include <ml307_mqtt.h>
@ -9,16 +11,9 @@
#define TAG "MQTT"
MqttProtocol::MqttProtocol(std::map<std::string, std::string>& config) {
MqttProtocol::MqttProtocol() {
event_group_handle_ = xEventGroupCreate();
endpoint_ = config["endpoint"];
client_id_ = config["client_id"];
username_ = config["username"];
password_ = config["password"];
subscribe_topic_ = config["subscribe_topic"];
publish_topic_ = config["publish_topic"];
StartMqttClient();
}
@ -39,6 +34,19 @@ bool MqttProtocol::StartMqttClient() {
delete mqtt_;
}
Settings settings("mqtt", false);
endpoint_ = settings.GetString("endpoint");
client_id_ = settings.GetString("client_id");
username_ = settings.GetString("username");
password_ = settings.GetString("password");
subscribe_topic_ = settings.GetString("subscribe_topic");
publish_topic_ = settings.GetString("publish_topic");
if (endpoint_.empty()) {
ESP_LOGE(TAG, "MQTT endpoint is not specified");
return false;
}
mqtt_ = Board::GetInstance().CreateMqtt();
mqtt_->SetKeepAlive(90);
@ -58,9 +66,7 @@ bool MqttProtocol::StartMqttClient() {
cJSON_Delete(root);
return;
}
if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) {
@ -70,6 +76,8 @@ bool MqttProtocol::StartMqttClient() {
on_audio_channel_closed_();
}
}
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
cJSON_Delete(root);
});
@ -89,13 +97,17 @@ bool MqttProtocol::StartMqttClient() {
void MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) {
ESP_LOGE(TAG, "Publish topic is not specified");
return;
}
mqtt_->Publish(publish_topic_, text);
}
void MqttProtocol::SendAudio(const std::string& data) {
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(data.size());
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
@ -111,14 +123,26 @@ void MqttProtocol::SendAudio(const std::string& data) {
ESP_LOGE(TAG, "Failed to encrypt audio data");
return;
}
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
udp_->Send(encrypted);
}
void MqttProtocol::SendState(const std::string& state) {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"state\",";
message += "\"state\":\"" + state + "\"";
message += "}";
SendText(message);
}
void MqttProtocol::SendAbort() {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"abort\"";
message += "}";
SendText(message);
}
void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
@ -129,6 +153,7 @@ void MqttProtocol::CloseAudioChannel() {
}
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"goodbye\"";
message += "}";
SendText(message);
@ -139,8 +164,8 @@ void MqttProtocol::CloseAudioChannel() {
}
bool MqttProtocol::OpenAudioChannel() {
if (!mqtt_->IsConnected()) {
ESP_LOGE(TAG, "MQTT is not connected, try to connect now");
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
if (!StartMqttClient()) {
ESP_LOGE(TAG, "Failed to connect to MQTT");
return false;
@ -155,7 +180,7 @@ bool MqttProtocol::OpenAudioChannel() {
message += "\"version\": 3,";
message += "\"transport\":\"udp\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":60";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}";
SendText(message);
@ -185,6 +210,9 @@ bool MqttProtocol::OpenAudioChannel() {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return;
}
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
std::string decrypted;
size_t decrypted_size = data.size() - aes_nonce_.size();
@ -240,6 +268,15 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
session_id_ = session_id->valuestring;
}
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
server_sample_rate_ = sample_rate->valueint;
}
}
auto udp = cJSON_GetObjectItem(root, "udp");
if (udp == nullptr) {
ESP_LOGE(TAG, "UDP is not specified");
@ -260,6 +297,10 @@ void MqttProtocol::ParseServerHello(const cJSON* root) {
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
int MqttProtocol::GetServerSampleRate() const {
return server_sample_rate_;
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
@ -279,3 +320,7 @@ std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
}
return decoded;
}
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr;
}

View File

@ -21,17 +21,21 @@
class MqttProtocol : public Protocol {
public:
MqttProtocol(std::map<std::string, std::string>& config);
MqttProtocol();
~MqttProtocol();
void OnIncomingAudio(std::function<void(const std::string& data)> callback);
void OnIncomingJson(std::function<void(const cJSON* root)> callback);
void SendAudio(const std::string& data);
void SendText(const std::string& text);
void SendState(const std::string& state);
void SendAbort();
bool OpenAudioChannel();
void CloseAudioChannel();
void OnAudioChannelOpened(std::function<void()> callback);
void OnAudioChannelClosed(std::function<void()> callback);
bool IsAudioChannelOpened() const;
int GetServerSampleRate() const;
private:
EventGroupHandle_t event_group_handle_;
@ -58,6 +62,7 @@ private:
uint32_t local_sequence_;
uint32_t remote_sequence_;
std::string session_id_;
int server_sample_rate_ = 16000;
bool StartMqttClient();
void ParseServerHello(const cJSON* root);

63
main/settings.cc Normal file
View File

@ -0,0 +1,63 @@
#include "settings.h"
#include <esp_log.h>
#include <nvs_flash.h>
#define TAG "Settings"
Settings::Settings(const std::string& ns, bool read_write) : ns_(ns), read_write_(read_write) {
nvs_open(ns.c_str(), read_write_ ? NVS_READWRITE : NVS_READONLY, &nvs_handle_);
}
Settings::~Settings() {
if (nvs_handle_ != 0) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_commit(nvs_handle_));
}
nvs_close(nvs_handle_);
}
}
std::string Settings::GetString(const std::string& key, const std::string& default_value) {
if (nvs_handle_ == 0) {
return default_value;
}
size_t length = 0;
if (nvs_get_str(nvs_handle_, key.c_str(), nullptr, &length) != ESP_OK) {
return default_value;
}
std::string value;
value.resize(length);
ESP_ERROR_CHECK(nvs_get_str(nvs_handle_, key.c_str(), value.data(), &length));
return value;
}
void Settings::SetString(const std::string& key, const std::string& value) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_set_str(nvs_handle_, key.c_str(), value.c_str()));
} else {
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
}
}
int32_t Settings::GetInt(const std::string& key, int32_t default_value) {
if (nvs_handle_ == 0) {
return default_value;
}
int32_t value;
if (nvs_get_i32(nvs_handle_, key.c_str(), &value) != ESP_OK) {
return default_value;
}
return value;
}
void Settings::SetInt(const std::string& key, int32_t value) {
if (read_write_) {
ESP_ERROR_CHECK(nvs_set_i32(nvs_handle_, key.c_str(), value));
} else {
ESP_LOGW(TAG, "Namespace %s is not open for writing", ns_.c_str());
}
}

23
main/settings.h Normal file
View File

@ -0,0 +1,23 @@
#ifndef SETTINGS_H
#define SETTINGS_H
#include <string>
#include <nvs_flash.h>
class Settings {
public:
Settings(const std::string& ns, bool read_write = false);
~Settings();
std::string GetString(const std::string& key, const std::string& default_value = "");
void SetString(const std::string& key, const std::string& value);
int32_t GetInt(const std::string& key, int32_t default_value = 0);
void SetInt(const std::string& key, int32_t value);
private:
std::string ns_;
nvs_handle_t nvs_handle_ = 0;
bool read_write_ = false;
};
#endif

View File

@ -170,6 +170,7 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
}
void WakeWordDetect::EncodeWakeWordData() {
xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT);
wake_word_opus_.clear();
if (wake_word_encode_task_stack_ == nullptr) {
wake_word_encode_task_stack_ = (StackType_t*)malloc(4096 * 8);
@ -192,7 +193,7 @@ void WakeWordDetect::EncodeWakeWordData() {
this_->wake_word_pcm_.clear();
auto end_time = esp_timer_get_time();
ESP_LOGI(TAG, "Encode wake word opus: %zu bytes in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT);
this_->wake_word_cv_.notify_one();
delete encoder;

View File

@ -5,7 +5,6 @@ CONFIG_BOOTLOADER_APP_ROLLBACK_ENABLE=y
CONFIG_HTTPD_MAX_REQ_HDR_LEN=2048
CONFIG_HTTPD_MAX_URI_LEN=2048
CONFIG_ESP_MAIN_TASK_STACK_SIZE=8192
CONFIG_PARTITION_TABLE_CUSTOM=y
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"