// Mirror of https://github.com/78/xiaozhi-esp32.git
// Synced 2025-05-17 15:20:29 +08:00 - 327 lines, 11 KiB, C++
#include "mqtt_protocol.h"
|
||
#include "board.h"
|
||
#include "application.h"
|
||
#include "settings.h"
|
||
|
||
#include <esp_log.h>
|
||
#include <ml307_mqtt.h>
|
||
#include <ml307_udp.h>
|
||
#include <cstring>
|
||
#include <arpa/inet.h>
|
||
#include "assets/lang_config.h"
|
||
|
||
#define TAG "MQTT"
|
||
|
||
// Constructs the protocol object and creates the FreeRTOS event group that
// the rest of this file uses to wait for / signal the server hello.
MqttProtocol::MqttProtocol()
{
    // Released again in the destructor through vEventGroupDelete().
    event_group_handle_ = xEventGroupCreate();
}
|
||
|
||
// Tears the protocol down: logs the event, destroys the UDP and MQTT
// transports and finally deletes the event group.
MqttProtocol::~MqttProtocol()
{
    ESP_LOGI(TAG, "MqttProtocol deinit");

    // `delete` on a null pointer is a no-op, so no explicit checks are needed.
    delete udp_;
    delete mqtt_;

    vEventGroupDelete(event_group_handle_);
}
|
||
|
||
bool MqttProtocol::Start() {
|
||
return StartMqttClient(false);
|
||
}
|
||
|
||
bool MqttProtocol::StartMqttClient(bool report_error) {
|
||
if (mqtt_ != nullptr) {
|
||
ESP_LOGW(TAG, "Mqtt client already started");
|
||
delete mqtt_;
|
||
}
|
||
|
||
Settings settings("mqtt", false);
|
||
endpoint_ = settings.GetString("endpoint");
|
||
client_id_ = settings.GetString("client_id");
|
||
username_ = settings.GetString("username");
|
||
password_ = settings.GetString("password");
|
||
publish_topic_ = settings.GetString("publish_topic");
|
||
|
||
if (endpoint_.empty()) {
|
||
ESP_LOGW(TAG, "MQTT endpoint is not specified");
|
||
if (report_error) {
|
||
SetError(Lang::Strings::SERVER_NOT_FOUND);
|
||
}
|
||
return false;
|
||
}
|
||
|
||
mqtt_ = Board::GetInstance().CreateMqtt();
|
||
mqtt_->SetKeepAlive(90);
|
||
|
||
mqtt_->OnDisconnected([this]() {
|
||
ESP_LOGI(TAG, "Disconnected from endpoint");
|
||
});
|
||
|
||
mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
|
||
cJSON* root = cJSON_Parse(payload.c_str());
|
||
if (root == nullptr) {
|
||
ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
|
||
return;
|
||
}
|
||
cJSON* type = cJSON_GetObjectItem(root, "type");
|
||
if (type == nullptr) {
|
||
ESP_LOGE(TAG, "Message type is not specified");
|
||
cJSON_Delete(root);
|
||
return;
|
||
}
|
||
|
||
if (strcmp(type->valuestring, "hello") == 0) {
|
||
ParseServerHello(root);
|
||
} else if (strcmp(type->valuestring, "goodbye") == 0) {
|
||
auto session_id = cJSON_GetObjectItem(root, "session_id");
|
||
ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
|
||
if (session_id == nullptr || session_id_ == session_id->valuestring) {
|
||
Application::GetInstance().Schedule([this]() {
|
||
CloseAudioChannel();
|
||
});
|
||
}
|
||
} else if (on_incoming_json_ != nullptr) {
|
||
on_incoming_json_(root);
|
||
}
|
||
cJSON_Delete(root);
|
||
last_incoming_time_ = std::chrono::steady_clock::now();
|
||
});
|
||
|
||
ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint_.c_str());
|
||
std::string broker_address;
|
||
int broker_port = 8883;
|
||
size_t pos = endpoint_.find(':');
|
||
if (pos != std::string::npos) {
|
||
broker_address = endpoint_.substr(0, pos);
|
||
broker_port = std::stoi(endpoint_.substr(pos + 1));
|
||
} else {
|
||
broker_address = endpoint_;
|
||
}
|
||
if (!mqtt_->Connect(broker_address, broker_port, client_id_, username_, password_)) {
|
||
ESP_LOGE(TAG, "Failed to connect to endpoint");
|
||
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
|
||
return false;
|
||
}
|
||
|
||
ESP_LOGI(TAG, "Connected to endpoint");
|
||
return true;
|
||
}
|
||
|
||
bool MqttProtocol::SendText(const std::string& text) {
|
||
if (publish_topic_.empty()) {
|
||
return false;
|
||
}
|
||
if (!mqtt_->Publish(publish_topic_, text)) {
|
||
ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
|
||
SetError(Lang::Strings::SERVER_ERROR);
|
||
return false;
|
||
}
|
||
return true;
|
||
}
|
||
|
||
void MqttProtocol::SendAudio(const std::vector<uint8_t>& data) {
|
||
std::lock_guard<std::mutex> lock(channel_mutex_);
|
||
if (udp_ == nullptr) {
|
||
return;
|
||
}
|
||
|
||
std::string nonce(aes_nonce_);
|
||
*(uint16_t*)&nonce[2] = htons(data.size());
|
||
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
|
||
|
||
std::string encrypted;
|
||
encrypted.resize(aes_nonce_.size() + data.size());
|
||
memcpy(encrypted.data(), nonce.data(), nonce.size());
|
||
|
||
size_t nc_off = 0;
|
||
uint8_t stream_block[16] = {0};
|
||
if (mbedtls_aes_crypt_ctr(&aes_ctx_, data.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
|
||
(uint8_t*)data.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
|
||
ESP_LOGE(TAG, "Failed to encrypt audio data");
|
||
return;
|
||
}
|
||
|
||
busy_sending_audio_ = true;
|
||
udp_->Send(encrypted);
|
||
busy_sending_audio_ = false;
|
||
}
|
||
|
||
void MqttProtocol::CloseAudioChannel() {
|
||
{
|
||
std::lock_guard<std::mutex> lock(channel_mutex_);
|
||
if (udp_ != nullptr) {
|
||
delete udp_;
|
||
udp_ = nullptr;
|
||
}
|
||
}
|
||
|
||
std::string message = "{";
|
||
message += "\"session_id\":\"" + session_id_ + "\",";
|
||
message += "\"type\":\"goodbye\"";
|
||
message += "}";
|
||
SendText(message);
|
||
|
||
if (on_audio_channel_closed_ != nullptr) {
|
||
on_audio_channel_closed_();
|
||
}
|
||
}
|
||
|
||
bool MqttProtocol::OpenAudioChannel() {
|
||
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
|
||
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
|
||
if (!StartMqttClient(true)) {
|
||
return false;
|
||
}
|
||
}
|
||
|
||
busy_sending_audio_ = false;
|
||
error_occurred_ = false;
|
||
session_id_ = "";
|
||
xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
|
||
|
||
// 发送 hello 消息申请 UDP 通道
|
||
std::string message = "{";
|
||
message += "\"type\":\"hello\",";
|
||
message += "\"version\": 3,";
|
||
message += "\"transport\":\"udp\",";
|
||
message += "\"audio_params\":{";
|
||
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
|
||
message += "}}";
|
||
if (!SendText(message)) {
|
||
return false;
|
||
}
|
||
|
||
// 等待服务器响应
|
||
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
|
||
if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
|
||
ESP_LOGE(TAG, "Failed to receive server hello");
|
||
SetError(Lang::Strings::SERVER_TIMEOUT);
|
||
return false;
|
||
}
|
||
|
||
std::lock_guard<std::mutex> lock(channel_mutex_);
|
||
if (udp_ != nullptr) {
|
||
delete udp_;
|
||
}
|
||
udp_ = Board::GetInstance().CreateUdp();
|
||
udp_->OnMessage([this](const std::string& data) {
|
||
/*
|
||
* UDP Encrypted OPUS Packet Format:
|
||
* |type 1u|flags 1u|payload_len 2u|ssrc 4u|timestamp 4u|sequence 4u|
|
||
* |payload payload_len|
|
||
*/
|
||
if (data.size() < sizeof(aes_nonce_)) {
|
||
ESP_LOGE(TAG, "Invalid audio packet size: %zu", data.size());
|
||
return;
|
||
}
|
||
if (data[0] != 0x01) {
|
||
ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
|
||
return;
|
||
}
|
||
uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
|
||
if (sequence < remote_sequence_) {
|
||
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
|
||
return;
|
||
}
|
||
if (sequence != remote_sequence_ + 1) {
|
||
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
|
||
}
|
||
|
||
std::vector<uint8_t> decrypted;
|
||
size_t decrypted_size = data.size() - aes_nonce_.size();
|
||
size_t nc_off = 0;
|
||
uint8_t stream_block[16] = {0};
|
||
decrypted.resize(decrypted_size);
|
||
auto nonce = (uint8_t*)data.data();
|
||
auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
|
||
int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)decrypted.data());
|
||
if (ret != 0) {
|
||
ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
|
||
return;
|
||
}
|
||
if (on_incoming_audio_ != nullptr) {
|
||
on_incoming_audio_(std::move(decrypted));
|
||
}
|
||
remote_sequence_ = sequence;
|
||
last_incoming_time_ = std::chrono::steady_clock::now();
|
||
});
|
||
|
||
udp_->Connect(udp_server_, udp_port_);
|
||
|
||
if (on_audio_channel_opened_ != nullptr) {
|
||
on_audio_channel_opened_();
|
||
}
|
||
return true;
|
||
}
|
||
|
||
void MqttProtocol::ParseServerHello(const cJSON* root) {
|
||
auto transport = cJSON_GetObjectItem(root, "transport");
|
||
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
|
||
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
|
||
return;
|
||
}
|
||
|
||
auto session_id = cJSON_GetObjectItem(root, "session_id");
|
||
if (session_id != nullptr) {
|
||
session_id_ = session_id->valuestring;
|
||
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
|
||
}
|
||
|
||
// Get sample rate from hello message
|
||
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
|
||
if (audio_params != NULL) {
|
||
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
|
||
if (sample_rate != NULL) {
|
||
server_sample_rate_ = sample_rate->valueint;
|
||
}
|
||
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
|
||
if (frame_duration != NULL) {
|
||
server_frame_duration_ = frame_duration->valueint;
|
||
}
|
||
}
|
||
|
||
auto udp = cJSON_GetObjectItem(root, "udp");
|
||
if (udp == nullptr) {
|
||
ESP_LOGE(TAG, "UDP is not specified");
|
||
return;
|
||
}
|
||
udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
|
||
udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
|
||
auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
|
||
auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
|
||
|
||
// auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
|
||
// ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
|
||
aes_nonce_ = DecodeHexString(nonce);
|
||
mbedtls_aes_init(&aes_ctx_);
|
||
mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
|
||
local_sequence_ = 0;
|
||
remote_sequence_ = 0;
|
||
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
|
||
}
|
||
|
||
static const char hex_chars[] = "0123456789ABCDEF";

// Helper: converts a single hexadecimal character ('0'-'9', 'A'-'F', 'a'-'f')
// to its numeric value. Any other character yields 0.
static inline uint8_t CharToHex(char c) {
    uint8_t value = 0;  // result for invalid input
    if ('0' <= c && c <= '9') {
        value = static_cast<uint8_t>(c - '0');
    } else if ('A' <= c && c <= 'F') {
        value = static_cast<uint8_t>(c - 'A' + 10);
    } else if ('a' <= c && c <= 'f') {
        value = static_cast<uint8_t>(c - 'a' + 10);
    }
    return value;
}
|
||
|
||
std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
|
||
std::string decoded;
|
||
decoded.reserve(hex_string.size() / 2);
|
||
for (size_t i = 0; i < hex_string.size(); i += 2) {
|
||
char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
|
||
decoded.push_back(byte);
|
||
}
|
||
return decoded;
|
||
}
|
||
|
||
bool MqttProtocol::IsAudioChannelOpened() const {
|
||
return udp_ != nullptr && !error_occurred_ && !IsTimeout();
|
||
}
|