diff --git a/examples/platforms/simulation/dnssd.c b/examples/platforms/simulation/dnssd.c index ab5ae998e..9b83227c6 100644 --- a/examples/platforms/simulation/dnssd.c +++ b/examples/platforms/simulation/dnssd.c @@ -164,4 +164,16 @@ void otPlatDnssdStopIp4AddressResolver(otInstance *aInstance, const otPlatDnssdA OT_UNUSED_VARIABLE(aResolver); } +void otPlatDnssdStartRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + +void otPlatDnssdStopRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + #endif // OPENTHREAD_CONFIG_PLATFORM_DNSSD_ENABLE && OPENTHREAD_SIMULATION_IMPLEMENT_DNSSD diff --git a/include/openthread/instance.h b/include/openthread/instance.h index 5e3c69303..8bc9a1a0f 100644 --- a/include/openthread/instance.h +++ b/include/openthread/instance.h @@ -52,7 +52,7 @@ extern "C" { * * @note This number versions both OpenThread platform and user APIs. */ -#define OPENTHREAD_API_VERSION (500) +#define OPENTHREAD_API_VERSION (501) /** * @addtogroup api-instance diff --git a/include/openthread/mdns.h b/include/openthread/mdns.h index cdad91e53..bb29a843c 100644 --- a/include/openthread/mdns.h +++ b/include/openthread/mdns.h @@ -612,36 +612,17 @@ typedef otPlatDnssdAddressResult otMdnsAddressResult; /** * Represents a record query result. */ -typedef struct otMdnsRecordResult -{ - const char *mFirstLabel; ///< The first label of the name to be queried. - const char *mNextLabels; ///< The rest of the name labels. Does not include domain name. Can be NULL. - uint16_t mRecordType; ///< The record type. - const uint8_t *mRecordData; ///< The record data bytes. - uint16_t mRecordDataLength; ///< Number of bytes in record data. - uint32_t mTtl; ///< TTL in seconds. Zero TTL indicates removal the data. - uint32_t mInfraIfIndex; ///< The infrastructure network interface index. -} otMdnsRecordResult; +typedef otPlatDnssdRecordResult otMdnsRecordResult; /** * Represents the callback function used to report a record querier result. - * - * @param[in] aInstance The OpenThread instance. - * @param[in] aResult The record querier result. */ -typedef void (*otMdnsRecordCallback)(otInstance *aInstance, const otMdnsRecordResult *aResult); +typedef otPlatDnssdRecordCallback otMdnsRecordCallback; /** * Represents a record querier. */ -typedef struct otMdnsRecordQuerier -{ - const char *mFirstLabel; ///< The first label of the name to be queried. MUST NOT be NULL. - const char *mNextLabels; ///< The rest of name labels, excluding domain name. Can be NULL. - uint16_t mRecordType; ///< The record type to query. - uint32_t mInfraIfIndex; ///< The infrastructure network interface index. - otMdnsRecordCallback mCallback; ///< The callback to report result. -} otMdnsRecordQuerier; +typedef otPlatDnssdRecordQuerier otMdnsRecordQuerier; /** * Starts a service browser. diff --git a/include/openthread/platform/dnssd.h b/include/openthread/platform/dnssd.h index a64589bea..beff2fcbb 100644 --- a/include/openthread/platform/dnssd.h +++ b/include/openthread/platform/dnssd.h @@ -539,6 +539,40 @@ typedef struct otPlatDnssdAddressResolver otPlatDnssdAddressCallback mCallback; ///< The callback to report result. } otPlatDnssdAddressResolver; +/** + * Represents a record query result. + */ +typedef struct otPlatDnssdRecordResult +{ + const char *mFirstLabel; ///< The first label of the name to be queried. + const char *mNextLabels; ///< The rest of the name labels. Does not include domain name. Can be NULL. + uint16_t mRecordType; ///< The record type. + const uint8_t *mRecordData; ///< The record data bytes. + uint16_t mRecordDataLength; ///< Number of bytes in record data. + uint32_t mTtl; ///< TTL in seconds. Zero TTL indicates removal the data. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. +} otPlatDnssdRecordResult; + +/** + * Represents the callback function used to report a record querier result. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aResult The record querier result. + */ +typedef void (*otPlatDnssdRecordCallback)(otInstance *aInstance, const otPlatDnssdRecordResult *aResult); + +/** + * Represents a record querier. + */ +typedef struct otPlatDnssdRecordQuerier +{ + const char *mFirstLabel; ///< The first label of the name to be queried. MUST NOT be NULL. + const char *mNextLabels; ///< The rest of name labels, excluding domain name. Can be NULL. + uint16_t mRecordType; ///< The record type to query. + uint32_t mInfraIfIndex; ///< The infrastructure network interface index. + otPlatDnssdRecordCallback mCallback; ///< The callback to report result. +} otPlatDnssdRecordQuerier; + /** * Starts a service browser. * @@ -718,6 +752,46 @@ void otPlatDnssdStartIp4AddressResolver(otInstance *aInstance, const otPlatDnssd */ void otPlatDnssdStopIp4AddressResolver(otInstance *aInstance, const otPlatDnssdAddressResolver *aResolver); +/** + * Starts a record querier. + * + * Initiates a continuous query for a given `mRecordType` as specified in @p aQuerier. The queried name is specified + * by the combination of `mFirstLabel` and `mNextLabels` (optional rest of the labels) in @p aQuerier. The + * `mFirstLabel` is always non-NULL but `mNextLabels` can be `NULL` if there are no other labels. The `mNextLabels + * does not include the domain name. The reason for a separate first label is to allow it to include a dot `.` + * character (as allowed for service instance labels). + * + * Discovered results should be reported through the `mCallback` function in @p aQuerier, providing the raw record + * data bytes. A removed record data is indicated with a TTL value of zero. The callback may be invoked immediately + * with cached information (if available) and potentially before this function returns. When cached results are used, + * the reported TTL value should reflect the original TTL from the last received response. + * + * Multiple querier instances can be started for the same name, provided they use different callback functions. + * + * OpenThread will only use a record querier for types other than PTR, SRV, TXT, A, and AAAA. For those, specific + * browsers or resolvers are used. The platform implementation, therefore, can choose to restrict its implementation. + * + * The @p aQuerier and all its contained information (strings) are only valid during this call. The platform MUST save + * a copy of the information if it wants to retain the information after returning from this function. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aQuerier The record querier to be started. + */ +void otPlatDnssdStartRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier); + +/** + * Stops a record querier. + * + * No action is performed if no matching querier with the same name, record type and callback is currently active. + * + * The @p aQuerier and all its contained information (strings) are only valid during this call. The platform MUST save + * a copy of the information if it wants to retain the information after returning from this function. + * + * @param[in] aInstance The OpenThread instance. + * @param[in] aQuerier The record querier to be stopped. + */ +void otPlatDnssdStopRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier); + /** * @} */ diff --git a/src/core/net/dns_types.cpp b/src/core/net/dns_types.cpp index a4ab14596..92c0444d6 100644 --- a/src/core/net/dns_types.cpp +++ b/src/core/net/dns_types.cpp @@ -1131,6 +1131,32 @@ exit: return error; } +ResourceRecord::TypeInfoString ResourceRecord::TypeToString(uint16_t aRecordType) +{ + static constexpr Stringify::Entry kRecordTypeTable[] = { + {kTypeA, "A"}, {kTypeNs, "NS"}, {kTypeCname, "CNAME"}, {kTypeSoa, "SOA"}, {kTypePtr, "PTR"}, + {kTypeMx, "MX"}, {kTypeTxt, "TXT"}, {kTypeRp, "RP"}, {kTypeAfsdb, "AFSDB"}, {kTypeRt, "RT"}, + {kTypeSig, "SIG"}, {kTypeKey, "KEY"}, {kTypePx, "PX"}, {kTypeAaaa, "AAAA"}, {kTypeSrv, "SRV"}, + {kTypeKx, "KX"}, {kTypeDname, "DNAME"}, {kTypeOpt, "OPT"}, {kTypeNsec, "NSEC"}, {kTypeAny, "ANY"}, + }; + + static_assert(Stringify::IsSorted(kRecordTypeTable), "kRecordTypeTable is not sorted"); + + TypeInfoString string; + const char *lookupResult = Stringify::Lookup(aRecordType, kRecordTypeTable, nullptr); + + if (lookupResult != nullptr) + { + string.Append("%s", lookupResult); + } + else + { + string.Append("RR:%u", aRecordType); + } + + return string; +} + void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength) { SetTxtData(aTxtData); diff --git a/src/core/net/dns_types.hpp b/src/core/net/dns_types.hpp index a95edec94..41e8d59da 100644 --- a/src/core/net/dns_types.hpp +++ b/src/core/net/dns_types.hpp @@ -46,6 +46,7 @@ #include "common/equatable.hpp" #include "common/message.hpp" #include "common/owned_ptr.hpp" +#include "common/string.hpp" #include "crypto/ecdsa.hpp" #include "net/ip4_types.hpp" #include "net/ip6_address.hpp" @@ -1293,6 +1294,10 @@ public: static constexpr uint16_t kClassNone = 254; ///< Class code None (NONE) - RFC 2136. static constexpr uint16_t kClassAny = 255; ///< Class code Any (ANY). + static constexpr uint16_t kTypeStringSize = 17; ///< Size of `TypeInfoString`. + + typedef String TypeInfoString; /// A string to represent a resource record type (human-readable). + /** * Initializes the resource record by setting its type and class. * @@ -1545,6 +1550,15 @@ public: */ static Error DecompressRecordData(const Message &aMessage, uint16_t aOffset, OwnedPtr &aDataMsg); + /** + * Returns a human-readable string representation of a given resource record type. + * + * @param[in] aRecordType The resource record type to convert. + * + * @returns human-readable string representation of a given resource record type. + */ + static TypeInfoString TypeToString(uint16_t aRecordType); + protected: Error ReadName(const Message &aMessage, uint16_t &aOffset, diff --git a/src/core/net/dnssd.cpp b/src/core/net/dnssd.cpp index 128c562c0..086b5a555 100644 --- a/src/core/net/dnssd.cpp +++ b/src/core/net/dnssd.cpp @@ -473,6 +473,50 @@ exit: return; } +void Dnssd::StartRecordQuerier(const RecordQuerier &aQuerier) +{ + VerifyOrExit(IsReady()); + +#if OPENTHREAD_CONFIG_PLATFORM_DNSSD_ALLOW_RUN_TIME_SELECTION + if (mUseNativeMdns) +#endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + { + IgnoreError(Get().StartRecordQuerier(aQuerier)); + ExitNow(); + } +#endif + +#if OPENTHREAD_CONFIG_PLATFORM_DNSSD_ENABLE + otPlatDnssdStartRecordQuerier(&GetInstance(), &aQuerier); +#endif + +exit: + return; +} + +void Dnssd::StopRecordQuerier(const RecordQuerier &aQuerier) +{ + VerifyOrExit(IsReady()); + +#if OPENTHREAD_CONFIG_PLATFORM_DNSSD_ALLOW_RUN_TIME_SELECTION + if (mUseNativeMdns) +#endif +#if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE + { + IgnoreError(Get().StopRecordQuerier(aQuerier)); + ExitNow(); + } +#endif + +#if OPENTHREAD_CONFIG_PLATFORM_DNSSD_ENABLE + otPlatDnssdStopRecordQuerier(&GetInstance(), &aQuerier); +#endif + +exit: + return; +} + void Dnssd::HandleStateChange(void) { #if OPENTHREAD_CONFIG_SRP_SERVER_ADVERTISING_PROXY_ENABLE diff --git a/src/core/net/dnssd.hpp b/src/core/net/dnssd.hpp index 3f1a58107..8b5044799 100644 --- a/src/core/net/dnssd.hpp +++ b/src/core/net/dnssd.hpp @@ -93,12 +93,14 @@ public: typedef otPlatDnssdBrowseCallback BrowseCallback; ///< Browser callback. typedef otPlatDnssdSrvCallback SrvCallback; ///< SRV callback. typedef otPlatDnssdTxtCallback TxtCallback; ///< TXT callback. - typedef otPlatDnssdAddressCallback AddressCallback; ///< Address callback + typedef otPlatDnssdAddressCallback AddressCallback; ///< Address callback. + typedef otPlatDnssdRecordCallback RecordCallback; ///< Record callback. typedef otPlatDnssdBrowseResult BrowseResult; ///< Browser result. typedef otPlatDnssdSrvResult SrvResult; ///< SRV result. typedef otPlatDnssdTxtResult TxtResult; ///< TXT result. typedef otPlatDnssdAddressResult AddressResult; ///< Address result. typedef otPlatDnssdAddressAndTtl AddressAndTtl; ///< Address and TTL. + typedef otPlatDnssdRecordResult RecordResult; ///< Record result. class Host : public otPlatDnssdHost, public Clearable ///< Host information. { @@ -128,6 +130,10 @@ public: { }; + class RecordQuerier : public otPlatDnssdRecordQuerier, public Clearable ///< Record querier. + { + }; + /** * Represents a range of `RequestId` values. * @@ -381,6 +387,26 @@ public: */ void StopIp4AddressResolver(const AddressResolver &aResolver); + /** + * Starts a record querier. + * + * Refer to the documentation for `otPlatDnssdStartRecordQuerier()` for a more detailed description of the + * behavior of this method. + * + * @param[in] aQuerier The querier to be started. + */ + void StartRecordQuerier(const RecordQuerier &aQuerier); + + /** + * Stops a record querier. + * + * Refer to the documentation for `otPlatDnssdStopRecordQuerier()` for a more detailed description of the + * behavior of this method. + * + * @param[in] aQuerier The querier to stop. + */ + void StopRecordQuerier(const RecordQuerier &aQuerier); + #if OPENTHREAD_CONFIG_MULTICAST_DNS_ENABLE /** * Handles native mDNS state change. diff --git a/src/core/net/dnssd_server.cpp b/src/core/net/dnssd_server.cpp index 83cc54bbc..a438f9376 100644 --- a/src/core/net/dnssd_server.cpp +++ b/src/core/net/dnssd_server.cpp @@ -310,26 +310,7 @@ Server::ResponseCode Server::Request::ParseQuestions(uint8_t aTestMode, bool &aS SuccessOrExit(mMessage->Read(offset, question)); offset += sizeof(question); - switch (question.GetType()) - { - case ResourceRecord::kTypePtr: - mType = kPtrQuery; - break; - case ResourceRecord::kTypeSrv: - mType = kSrvQuery; - break; - case ResourceRecord::kTypeTxt: - mType = kTxtQuery; - break; - case ResourceRecord::kTypeAaaa: - mType = kAaaaQuery; - break; - case ResourceRecord::kTypeA: - mType = kAQuery; - break; - default: - ExitNow(rcode = Header::kResponseNotImplemented); - } + mQuestions.mFirstRrType = question.GetType(); if (questionCount > 1) { @@ -338,24 +319,15 @@ Server::ResponseCode Server::Request::ParseQuestions(uint8_t aTestMode, bool &aS VerifyOrExit(questionCount == 2); + // Allow SRV and TXT questions for the same service + // instance name in the same query. + SuccessOrExit(Name::CompareName(*mMessage, offset, *mMessage, sizeof(Header))); SuccessOrExit(mMessage->Read(offset, question)); - switch (question.GetType()) - { - case ResourceRecord::kTypeSrv: - VerifyOrExit(mType == kTxtQuery); - break; + mQuestions.mSecondRrType = question.GetType(); - case ResourceRecord::kTypeTxt: - VerifyOrExit(mType == kSrvQuery); - break; - - default: - ExitNow(); - } - - mType = kSrvTxtQuery; + VerifyOrExit(mQuestions.IsFor(kRrTypeSrv) && mQuestions.IsFor(kRrTypeTxt)); } rcode = Header::kResponseSuccess; @@ -369,7 +341,7 @@ Server::ResponseCode Server::Response::AddQuestionsFrom(const Request &aRequest) ResponseCode rcode = Header::kResponseServerFailure; uint16_t offset; - mType = aRequest.mType; + mQuestions = aRequest.mQuestions; // Read the name from `aRequest.mMessage` and append it as is to // the response message. This ensures all name formats, including @@ -424,24 +396,19 @@ Error Server::Response::ParseQueryName(void) offset = sizeof(Header); SuccessOrExit(error = Name::ReadName(*mMessage, offset, name)); - switch (mType) + if (mQuestions.IsFor(kRrTypePtr)) { - case kPtrQuery: // `mOffsets.mServiceName` may be updated as we read labels and if we // determine that the query name is a sub-type service. mOffsets.mServiceName = sizeof(Header); - break; - - case kSrvQuery: - case kTxtQuery: - case kSrvTxtQuery: + } + else if (mQuestions.IsFor(kRrTypeSrv) || mQuestions.IsFor(kRrTypeTxt)) + { mOffsets.mInstanceName = sizeof(Header); - break; - - case kAaaaQuery: - case kAQuery: + } + else + { mOffsets.mHostName = sizeof(Header); - break; } // Read the query name labels one by one to check if the name is @@ -458,7 +425,7 @@ Error Server::Response::ParseQueryName(void) SuccessOrExit(error = Name::ReadLabel(*mMessage, offset, label, labelLength)); - if ((mType == kPtrQuery) && StringMatch(label, kSubLabel, kStringCaseInsensitiveMatch)) + if (mQuestions.IsFor(kRrTypePtr) && StringMatch(label, kSubLabel, kStringCaseInsensitiveMatch)) { mOffsets.mServiceName = offset; } @@ -692,6 +659,44 @@ exit: return error; } +#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE +Error Server::Response::AppendKeyRecord(const Srp::Server::Host &aHost) +{ + Ecdsa256KeyRecord keyRecord; + uint32_t ttl; + + keyRecord.Init(); + keyRecord.SetFlags(KeyRecord::kAuthConfidPermitted, KeyRecord::kOwnerNonZone, KeyRecord::kSignatoryFlagGeneral); + keyRecord.SetProtocol(KeyRecord::kProtocolDnsSec); + keyRecord.SetAlgorithm(KeyRecord::kAlgorithmEcdsaP256Sha256); + keyRecord.SetLength(sizeof(Ecdsa256KeyRecord) - sizeof(ResourceRecord)); + keyRecord.SetKey(aHost.GetKey()); + + ttl = TimeMilli::MsecToSec(aHost.GetExpireTime() - TimerMilli::GetNow()); + + return AppendGenericRecord(Ecdsa256KeyRecord::kType, &keyRecord, sizeof(keyRecord), ttl); +} +#endif + +Error Server::Response::AppendGenericRecord(uint16_t aRrType, const void *aData, uint16_t aDataLength, uint32_t aTtl) +{ + Error error = kErrorNone; + ResourceRecord record; + + record.Init(aRrType); + record.SetTtl(aTtl); + record.SetLength(aDataLength); + + SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mHostName, *mMessage)); + SuccessOrExit(error = mMessage->Append(record)); + SuccessOrExit(error = mMessage->AppendBytes(aData, aDataLength)); + + IncResourceRecordCount(); + +exit: + return error; +} + void Server::Response::IncResourceRecordCount(void) { switch (mSection) @@ -709,34 +714,13 @@ void Server::Response::IncResourceRecordCount(void) void Server::Response::Log(void) const { Name::Buffer name; + bool hasTwoQuestions = (mQuestions.mSecondRrType != 0); ReadQueryName(name); - LogInfo("%s query for '%s'", QueryTypeToString(mType), name); -} -const char *Server::Response::QueryTypeToString(QueryType aType) -{ - static const char *const kTypeNames[] = { - "PTR", // (0) kPtrQuery - "SRV", // (1) kSrvQuery - "TXT", // (2) kTxtQuery - "SRV & TXT", // (3) kSrvTxtQuery - "AAAA", // (4) kAaaaQuery - "A", // (5) kAQuery - }; - - struct EumCheck - { - InitEnumValidatorCounter(); - ValidateNextEnum(kPtrQuery); - ValidateNextEnum(kSrvQuery); - ValidateNextEnum(kTxtQuery); - ValidateNextEnum(kSrvTxtQuery); - ValidateNextEnum(kAaaaQuery); - ValidateNextEnum(kAQuery); - }; - - return kTypeNames[aType]; + LogInfo("%s%s%s query for '%s'", ResourceRecord::TypeToString(mQuestions.mFirstRrType).AsCString(), + hasTwoQuestions ? " and " : "", + hasTwoQuestions ? ResourceRecord::TypeToString(mQuestions.mSecondRrType).AsCString() : "", name); } #endif @@ -746,11 +730,9 @@ Error Server::Response::ResolveBySrp(void) { static const Section kSections[] = {kAnswerSection, kAdditionalDataSection}; - Error error = kErrorNotFound; + Error error = kErrorNone; const Srp::Server::Service *matchedService = nullptr; bool found = false; - Section srvSection; - Section txtSection; mSection = kAnswerSection; @@ -761,19 +743,22 @@ Error Server::Response::ResolveBySrp(void) continue; } - if ((mType == kAaaaQuery) || (mType == kAQuery)) + if (QueryNameMatches(host.GetFullName())) { - if (QueryNameMatches(host.GetFullName())) + if (mQuestions.IsFor(kRrTypeAaaa) || mQuestions.IsFor(kRrTypeA)) { - mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection; - error = AppendHostAddresses(host); - ExitNow(); + mSection = mQuestions.SectionFor(kRrTypeAaaa); + SuccessOrExit(error = AppendHostAddresses(host)); } - continue; - } + if (mQuestions.IsFor(kRrTypeKey)) + { + mSection = kAnswerSection; + SuccessOrExit(error = AppendKeyRecord(host)); + } - // `mType` is PTR or SRV/TXT query + ExitNow(); + } for (const Srp::Server::Service &service : host.GetServices()) { @@ -782,7 +767,7 @@ Error Server::Response::ResolveBySrp(void) continue; } - if (mType == kPtrQuery) + if (mQuestions.IsFor(kRrTypePtr)) { if (QueryNameMatchesService(service)) { @@ -806,9 +791,9 @@ Error Server::Response::ResolveBySrp(void) } } - VerifyOrExit(matchedService != nullptr); + VerifyOrExit(matchedService != nullptr, error = kErrorNotFound); - if (mType == kPtrQuery) + if (mQuestions.IsFor(kRrTypePtr)) { // Skip adding additional records, when answering a // PTR query with more than one answer. This is the @@ -817,9 +802,20 @@ Error Server::Response::ResolveBySrp(void) VerifyOrExit(mHeader.GetAnswerCount() == 1); } + else + { + if (mQuestions.IsFor(kRrTypeKey)) + { + mSection = kAnswerSection; + error = AppendKeyRecord(matchedService->GetHost()); + ExitNow(); + } - srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; - txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; + VerifyOrExit(mQuestions.IsFor(kRrTypeSrv) || mQuestions.IsFor(kRrTypeTxt)); + } + + // Append SRV and TXT records along with associated host AAAA addresses + // in the proper sections. for (Section section : kSections) { @@ -830,12 +826,12 @@ Error Server::Response::ResolveBySrp(void) VerifyOrExit(!(Get().mTestMode & kTestModeEmptyAdditionalSection)); } - if (srvSection == mSection) + if (mSection == mQuestions.SectionFor(kRrTypeSrv)) { SuccessOrExit(error = AppendSrvRecord(*matchedService)); } - if (txtSection == mSection) + if (mSection == mQuestions.SectionFor(kRrTypeTxt)) { SuccessOrExit(error = AppendTxtRecord(*matchedService)); } @@ -971,7 +967,7 @@ void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessag // We try to convert `aResponse.mMessage` to a `ProxyQuery` by // appending `ProxyQueryInfo` to it. - info.mType = aResponse.mType; + info.mQuestions = aResponse.mQuestions; info.mMessageInfo = aMessageInfo; info.mExpireTime = TimerMilli::GetNow() + kQueryTimeout; info.mOffsets = aResponse.mOffsets; @@ -1099,11 +1095,23 @@ void Server::ConstructFullName(const char *aLabels, Name::Buffer &aFullName) fullName.Append("%s.%s", aLabels, kDefaultDomainName); } -void Server::ConstructFullInstanceName(const char *aInstanceLabel, const char *aServiceType, Name::Buffer &aFullName) +void Server::ConstructFullName(const char *aFirstLabel, const char *aNextLabels, Name::Buffer &aFullName) { StringWriter fullName(aFullName, sizeof(aFullName)); - fullName.Append("%s.%s.%s", aInstanceLabel, aServiceType, kDefaultDomainName); + fullName.Append("%s.", aFirstLabel); + + if (aNextLabels != nullptr) + { + fullName.Append("%s.", aNextLabels); + } + + fullName.Append("%s", kDefaultDomainName); +} + +void Server::ConstructFullInstanceName(const char *aInstanceLabel, const char *aServiceType, Name::Buffer &aFullName) +{ + ConstructFullName(aInstanceLabel, aServiceType, aFullName); } void Server::ConstructFullServiceSubTypeName(const char *aServiceType, @@ -1150,19 +1158,17 @@ void Server::Response::InitFrom(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo) { mMessage.Reset(&aQuery); IgnoreError(mMessage->Read(0, mHeader)); - mType = aInfo.mType; - mOffsets = aInfo.mOffsets; + mQuestions = aInfo.mQuestions; + mOffsets = aInfo.mOffsets; } void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip6::MessageInfo &aMessageInfo) { static const Section kSections[] = {kAnswerSection, kAdditionalDataSection}; - Error error = kErrorNone; - Section srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; - Section txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; + Error error = kErrorNone; - if (mType == kPtrQuery) + if (mQuestions.IsFor(kRrTypePtr)) { Name::LabelBuffer instanceLabel; @@ -1180,12 +1186,12 @@ void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip VerifyOrExit(!(Get().mTestMode & kTestModeEmptyAdditionalSection)); } - if (srvSection == mSection) + if (mSection == mQuestions.SectionFor(kRrTypeSrv)) { SuccessOrExit(error = AppendSrvRecord(aInstanceInfo)); } - if (txtSection == mSection) + if (mSection == mQuestions.SectionFor(kRrTypeTxt)) { SuccessOrExit(error = AppendTxtRecord(aInstanceInfo)); } @@ -1204,10 +1210,9 @@ exit: void Server::Response::Answer(const HostInfo &aHostInfo, const Ip6::MessageInfo &aMessageInfo) { - // Caller already ensures that `mType` is either `kAaaaQuery` or - // `kAQuery`. + // Caller already ensures that question is either for AAAA or A record. - AddrType addrType = (mType == kAaaaQuery) ? kIp6AddrType : kIp4AddrType; + AddrType addrType = mQuestions.IsFor(kRrTypeAaaa) ? kIp6AddrType : kIp4AddrType; mSection = kAnswerSection; @@ -1243,21 +1248,13 @@ void Server::HandleDiscoveredServiceInstance(const char *aServiceFullName, const info.ReadFrom(query); - switch (info.mType) + if (info.mQuestions.IsFor(kRrTypePtr)) { - case kPtrQuery: canAnswer = QueryNameMatches(query, aServiceFullName); - break; - - case kSrvQuery: - case kTxtQuery: - case kSrvTxtQuery: + } + else if (info.mQuestions.IsFor(kRrTypeSrv) || info.mQuestions.IsFor(kRrTypeTxt)) + { canAnswer = QueryNameMatches(query, aInstanceInfo.mFullName); - break; - - case kAaaaQuery: - case kAQuery: - break; } if (canAnswer) @@ -1280,22 +1277,17 @@ void Server::HandleDiscoveredHost(const char *aHostFullName, const HostInfo &aHo info.ReadFrom(query); - switch (info.mType) + if (!info.mQuestions.IsFor(kRrTypeAaaa) && !info.mQuestions.IsFor(kRrTypeA)) { - case kAaaaQuery: - case kAQuery: - if (QueryNameMatches(query, aHostFullName)) - { - Response response(GetInstance()); + continue; + } - RemoveQueryAndPrepareResponse(query, info, response); - response.Answer(aHostInfo, info.mMessageInfo); - } + if (QueryNameMatches(query, aHostFullName)) + { + Response response(GetInstance()); - break; - - default: - break; + RemoveQueryAndPrepareResponse(query, info, response); + response.Answer(aHostInfo, info.mMessageInfo); } } } @@ -1316,23 +1308,15 @@ Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns ReadQueryName(*query, aName); info.ReadFrom(*query); - type = kDnsQueryBrowse; + type = kDnsQueryResolveHost; - switch (info.mType) + if (info.mQuestions.IsFor(kRrTypePtr)) + { + type = kDnsQueryBrowse; + } + else if (info.mQuestions.IsFor(kRrTypeSrv) || info.mQuestions.IsFor(kRrTypeTxt)) { - case kPtrQuery: - break; - - case kSrvQuery: - case kTxtQuery: - case kSrvTxtQuery: type = kDnsQueryResolve; - break; - - case kAaaaQuery: - case kAQuery: - type = kDnsQueryResolveHost; - break; } return type; @@ -1493,31 +1477,41 @@ exit: void Server::DiscoveryProxy::Resolve(ProxyQuery &aQuery, ProxyQueryInfo &aInfo) { - ProxyAction action = kNoAction; + // Determine which proxy action to start with based on the query's + // question record type(s). Note that the order in which the record + // types are checked is important. Particularly if the query + // contains questions for both SRV and TXT records, we want to + // start with the `kResolvingSrv` action first. - switch (aInfo.mType) + struct ActionEntry { - case kPtrQuery: - action = kBrowsing; - break; + uint16_t mRrType; + ProxyAction mAction; + }; - case kSrvQuery: - case kSrvTxtQuery: - action = kResolvingSrv; - break; + static const ActionEntry kActionTable[] = { + {kRrTypePtr, kBrowsing}, // PTR -> Browser + {kRrTypeSrv, kResolvingSrv}, // SRV -> SrvResolver + {kRrTypeTxt, kResolvingTxt}, // TXT -> TxtResolver + {kRrTypeAaaa, kResolvingIp6Address}, // AAAA -> Ip6AddressResolver + {kRrTypeA, kResolvingIp4Address}, // A -> Ip4AddressResolver + // Misc -> RecordQuerier + }; - case kTxtQuery: - action = kResolvingTxt; - break; + ProxyAction action; - case kAaaaQuery: - action = kResolvingIp6Address; - break; - case kAQuery: - action = kResolvingIp4Address; - break; + for (const ActionEntry &entry : kActionTable) + { + if (aInfo.mQuestions.IsFor(entry.mRrType)) + { + action = entry.mAction; + ExitNow(); + } } + action = kQueryingRecord; + +exit: Perform(action, aQuery, aInfo); } @@ -1525,6 +1519,7 @@ void Server::DiscoveryProxy::Perform(ProxyAction aAction, ProxyQuery &aQuery, Pr { bool shouldStart; Name::Buffer name; + uint16_t querierRrType; VerifyOrExit(aAction != kNoAction); @@ -1540,7 +1535,9 @@ void Server::DiscoveryProxy::Perform(ProxyAction aAction, ProxyQuery &aQuery, Pr ReadNameFor(aAction, aQuery, aInfo, name); - shouldStart = !HasActive(aAction, name); + querierRrType = (aAction == kQueryingRecord) ? aInfo.mQuestions.mFirstRrType : 0; + + shouldStart = !HasActive(aAction, name, querierRrType); aInfo.mAction = aAction; aInfo.UpdateIn(aQuery); @@ -1572,6 +1569,7 @@ void Server::DiscoveryProxy::ReadNameFor(ProxyAction aAction, break; case kResolvingIp6Address: case kResolvingIp4Address: + case kQueryingRecord: ReadQueryHostName(aQuery, aInfo, aName); break; } @@ -1585,6 +1583,7 @@ void Server::DiscoveryProxy::CancelAction(ProxyQuery &aQuery, ProxyQueryInfo &aI ProxyAction action = aInfo.mAction; Name::Buffer name; + uint16_t querierRrType; VerifyOrExit(mIsRunning); VerifyOrExit(action != kNoAction); @@ -1592,14 +1591,15 @@ void Server::DiscoveryProxy::CancelAction(ProxyQuery &aQuery, ProxyQueryInfo &aI // We first update the `aInfo` on `aQuery` before calling // `HasActive()`. This ensures that the current query is not // taken into account when we try to determine if any query - // is waiting for same `aAction` browser/resolver. + // is waiting for same `action` browser/resolver. ReadNameFor(action, aQuery, aInfo, name); + querierRrType = (action == kQueryingRecord) ? aInfo.mQuestions.mFirstRrType : 0; aInfo.mAction = kNoAction; aInfo.UpdateIn(aQuery); - VerifyOrExit(!HasActive(action, name)); + VerifyOrExit(!HasActive(action, name, querierRrType)); UpdateProxy(kStop, action, aQuery, aInfo, name); exit: @@ -1634,6 +1634,9 @@ void Server::DiscoveryProxy::UpdateProxy(Command aCommand, case kResolvingIp4Address: StartOrStopIp4Resolver(aCommand, aName); break; + case kQueryingRecord: + StartOrStopRecordQuerier(aCommand, aQuery, aInfo); + break; } } @@ -1795,18 +1798,59 @@ void Server::DiscoveryProxy::StartOrStopIp4Resolver(Command aCommand, Name::Buff } } +void Server::DiscoveryProxy::StartOrStopRecordQuerier(Command aCommand, + const ProxyQuery &aQuery, + const ProxyQueryInfo &aInfo) +{ + // Start or stop a record querier. + + Dnssd::RecordQuerier querier; + Name::LabelBuffer firstLabel; + Name::Buffer nextLabels; + uint16_t offset = aInfo.mOffsets.mHostName; + uint8_t labelLength = sizeof(firstLabel); + + IgnoreError(Dns::Name::ReadLabel(aQuery, offset, firstLabel, labelLength)); + IgnoreError(Dns::Name::ReadName(aQuery, offset, nextLabels)); + + querier.mFirstLabel = firstLabel; + querier.mNextLabels = (StripDomainName(nextLabels) == kErrorNone) ? nextLabels : nullptr; + querier.mRecordType = aInfo.mQuestions.mFirstRrType; + querier.mInfraIfIndex = Get().GetIfIndex(); + querier.mCallback = HandleRecordResult; + + switch (aCommand) + { + case kStart: + Get().StartRecordQuerier(querier); + break; + + case kStop: + Get().StopRecordQuerier(querier); + break; + } +} + bool Server::DiscoveryProxy::QueryMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, ProxyAction aAction, - const Name::Buffer &aName) const + const Name::Buffer &aName, + uint16_t aQuerierRrType) const { // Check whether `aQuery` is performing `aAction` and - // its name matches `aName`. + // its name matches `aName`. The `aQuerierRrType` is only + // used when the action is `kQueryingRecord` to indicate + // which record is being queried. bool matches = false; VerifyOrExit(aInfo.mAction == aAction); + if (aAction == kQueryingRecord) + { + VerifyOrExit(aInfo.mQuestions.IsFor(aQuerierRrType)); + } + switch (aAction) { case kBrowsing: @@ -1818,6 +1862,7 @@ bool Server::DiscoveryProxy::QueryMatches(const ProxyQuery &aQuery, break; case kResolvingIp6Address: case kResolvingIp4Address: + case kQueryingRecord: VerifyOrExit(QueryHostNameMatches(aQuery, aInfo, aName)); break; case kNoAction: @@ -1830,10 +1875,12 @@ exit: return matches; } -bool Server::DiscoveryProxy::HasActive(ProxyAction aAction, const Name::Buffer &aName) const +bool Server::DiscoveryProxy::HasActive(ProxyAction aAction, const Name::Buffer &aName, uint16_t aQuerierRrType) const { - // Determine whether or not we have an active browser/resolver - // corresponding to `aAction` for `aName`. + // Determine whether or not we have an active browser, resolver, or record + // querier corresponding to `aAction` for `aName`. The `aQuerierRrType` + // is only used when the action is `kQueryingRecord` to indicate the + // `RecordQuerier` record type. bool has = false; @@ -1843,7 +1890,7 @@ bool Server::DiscoveryProxy::HasActive(ProxyAction aAction, const Name::Buffer & info.ReadFrom(query); - if (QueryMatches(query, info, aAction, aName)) + if (QueryMatches(query, info, aAction, aName, aQuerierRrType)) { has = true; break; @@ -1999,6 +2046,26 @@ exit: return; } +void Server::DiscoveryProxy::HandleRecordResult(otInstance *aInstance, const otPlatDnssdRecordResult *aResult) +{ + AsCoreType(aInstance).Get().mDiscoveryProxy.HandleRecordResult(*aResult); +} + +void Server::DiscoveryProxy::HandleRecordResult(const Dnssd::RecordResult &aResult) +{ + Name::Buffer name; + + VerifyOrExit(mIsRunning); + VerifyOrExit(aResult.mTtl != 0); + VerifyOrExit(aResult.mInfraIfIndex == Get().GetIfIndex()); + + ConstructFullName(aResult.mFirstLabel, aResult.mNextLabels, name); + HandleResult(kQueryingRecord, name, &Response::AppendGenericRecord, ProxyResult(aResult)); + +exit: + return; +} + void Server::DiscoveryProxy::HandleResult(ProxyAction aAction, const Name::Buffer &aName, ResponseAppender aAppender, @@ -2015,6 +2082,9 @@ void Server::DiscoveryProxy::HandleResult(ProxyAction aAction, ProxyQueryList nextActionQueries; ProxyQueryInfo info; ProxyAction nextAction; + uint16_t querierRrType; + + querierRrType = (aAction == kQueryingRecord) ? aResult.mRecordResult->mRecordType : 0; for (ProxyQuery &query : Get().mProxyQueries) { @@ -2023,7 +2093,7 @@ void Server::DiscoveryProxy::HandleResult(ProxyAction aAction, info.ReadFrom(query); - if (!QueryMatches(query, info, aAction, aName)) + if (!QueryMatches(query, info, aAction, aName, querierRrType)) { continue; } @@ -2038,21 +2108,26 @@ void Server::DiscoveryProxy::HandleResult(ProxyAction aAction, nextAction = kResolvingSrv; break; case kResolvingSrv: - nextAction = (info.mType == kSrvQuery) ? kResolvingIp6Address : kResolvingTxt; + nextAction = (info.mQuestions.IsFor(kRrTypeSrv) && !info.mQuestions.IsFor(kRrTypeTxt)) + ? kResolvingIp6Address + : kResolvingTxt; break; case kResolvingTxt: - nextAction = (info.mType == kTxtQuery) ? kNoAction : kResolvingIp6Address; + nextAction = (info.mQuestions.IsFor(kRrTypeTxt) && !info.mQuestions.IsFor(kRrTypeSrv)) + ? kNoAction + : kResolvingIp6Address; break; case kNoAction: case kResolvingIp6Address: case kResolvingIp4Address: + case kQueryingRecord: break; } shouldFinalize = (nextAction == kNoAction); if ((Get().mTestMode & kTestModeEmptyAdditionalSection) && - IsActionForAdditionalSection(nextAction, info.mType)) + IsActionForAdditionalSection(nextAction, info.mQuestions)) { shouldFinalize = true; } @@ -2123,33 +2198,35 @@ void Server::DiscoveryProxy::HandleResult(ProxyAction aAction, } } -bool Server::DiscoveryProxy::IsActionForAdditionalSection(ProxyAction aAction, QueryType aQueryType) +bool Server::DiscoveryProxy::IsActionForAdditionalSection(ProxyAction aAction, const Questions &aQuestions) { - bool isForAddnlSection = false; + bool isForAddnlSection = false; + uint16_t rrType = 0; switch (aAction) { case kResolvingSrv: - VerifyOrExit((aQueryType == kSrvQuery) || (aQueryType == kSrvTxtQuery)); + rrType = kRrTypeSrv; break; case kResolvingTxt: - VerifyOrExit((aQueryType == kTxtQuery) || (aQueryType == kSrvTxtQuery)); + rrType = kRrTypeTxt; break; case kResolvingIp6Address: - VerifyOrExit(aQueryType == kAaaaQuery); + rrType = kRrTypeAaaa; break; case kResolvingIp4Address: - VerifyOrExit(aQueryType == kAQuery); + rrType = kRrTypeA; break; case kNoAction: case kBrowsing: + case kQueryingRecord: ExitNow(); } - isForAddnlSection = true; + isForAddnlSection = aQuestions.SectionFor(rrType) == kAdditionalDataSection; exit: return isForAddnlSection; @@ -2169,7 +2246,7 @@ Error Server::Response::AppendSrvRecord(const ProxyResult &aResult) const Dnssd::SrvResult *srvResult = aResult.mSrvResult; Name::Buffer fullHostName; - mSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; + mSection = mQuestions.SectionFor(kRrTypeSrv); ConstructFullName(srvResult->mHostName, fullHostName); @@ -2180,7 +2257,7 @@ Error Server::Response::AppendTxtRecord(const ProxyResult &aResult) { const Dnssd::TxtResult *txtResult = aResult.mTxtResult; - mSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection; + mSection = mQuestions.SectionFor(kRrTypeTxt); return AppendTxtRecord(txtResult->mTxtData, txtResult->mTxtDataLength, txtResult->mTtl); } @@ -2190,7 +2267,7 @@ Error Server::Response::AppendHostIp6Addresses(const ProxyResult &aResult) Error error = kErrorNone; const Dnssd::AddressResult *addrResult = aResult.mAddressResult; - mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection; + mSection = mQuestions.SectionFor(kRrTypeAaaa); for (uint16_t index = 0; index < addrResult->mAddressesLength; index++) { @@ -2219,7 +2296,7 @@ Error Server::Response::AppendHostIp4Addresses(const ProxyResult &aResult) Error error = kErrorNone; const Dnssd::AddressResult *addrResult = aResult.mAddressResult; - mSection = (mType == kAQuery) ? kAnswerSection : kAdditionalDataSection; + mSection = mQuestions.SectionFor(kRrTypeA); for (uint16_t index = 0; index < addrResult->mAddressesLength; index++) { @@ -2238,6 +2315,15 @@ exit: return error; } +Error Server::Response::AppendGenericRecord(const ProxyResult &aResult) +{ + const Dnssd::RecordResult *result = aResult.mRecordResult; + + mSection = kAnswerSection; + + return AppendGenericRecord(result->mRecordType, result->mRecordData, result->mRecordDataLength, result->mTtl); +} + bool Server::IsProxyAddressValid(const Ip6::Address &aAddress) { return !aAddress.IsLinkLocalUnicast() && !aAddress.IsMulticast() && !aAddress.IsUnspecified() && diff --git a/src/core/net/dnssd_server.hpp b/src/core/net/dnssd_server.hpp index f9eed8c0f..f88a79f0d 100644 --- a/src/core/net/dnssd_server.hpp +++ b/src/core/net/dnssd_server.hpp @@ -311,21 +311,20 @@ private: static constexpr uint32_t kQueryTimeout = OPENTHREAD_CONFIG_DNSSD_QUERY_TIMEOUT; static constexpr uint16_t kMaxConcurrentUpstreamQueries = 32; + static constexpr uint16_t kRrTypeA = ResourceRecord::kTypeA; + static constexpr uint16_t kRrTypeSoa = ResourceRecord::kTypeSoa; + static constexpr uint16_t kRrTypeCname = ResourceRecord::kTypeCname; + static constexpr uint16_t kRrTypePtr = ResourceRecord::kTypePtr; + static constexpr uint16_t kRrTypeTxt = ResourceRecord::kTypeTxt; + static constexpr uint16_t kRrTypeKey = ResourceRecord::kTypeKey; + static constexpr uint16_t kRrTypeAaaa = ResourceRecord::kTypeAaaa; + static constexpr uint16_t kRrTypeSrv = ResourceRecord::kTypeSrv; + typedef Header::Response ResponseCode; typedef Message ProxyQuery; typedef MessageQueue ProxyQueryList; - enum QueryType : uint8_t - { - kPtrQuery, - kSrvQuery, - kTxtQuery, - kSrvTxtQuery, - kAaaaQuery, - kAQuery, - }; - enum Section : uint8_t { kAnswerSection, @@ -346,10 +345,22 @@ private: kResolvingSrv, kResolvingTxt, kResolvingIp6Address, - kResolvingIp4Address + kResolvingIp4Address, + kQueryingRecord, }; #endif + struct Questions + { + Questions(void) { mFirstRrType = 0, mSecondRrType = 0; } + + bool IsFor(uint16_t aRrType) const { return (mFirstRrType == aRrType) || (mSecondRrType == aRrType); } + Section SectionFor(uint16_t aRrType) const { return IsFor(aRrType) ? kAnswerSection : kAdditionalDataSection; } + + uint16_t mFirstRrType; + uint16_t mSecondRrType; + }; + struct Request { ResponseCode ParseQuestions(uint8_t aTestMode, bool &aShouldRespond); @@ -357,7 +368,7 @@ private: const Message *mMessage; const Ip6::MessageInfo *mMessageInfo; Header mHeader; - QueryType mType; + Questions mQuestions; }; struct ProxyQueryInfo; @@ -377,11 +388,13 @@ private: explicit ProxyResult(const Dnssd::SrvResult &aSrvResult) { mSrvResult = &aSrvResult; } explicit ProxyResult(const Dnssd::TxtResult &aTxtResult) { mTxtResult = &aTxtResult; } explicit ProxyResult(const Dnssd::AddressResult &aAddressResult) { mAddressResult = &aAddressResult; } + explicit ProxyResult(const Dnssd::RecordResult &aRecordResult) { mRecordResult = &aRecordResult; } const Dnssd::BrowseResult *mBrowseResult; const Dnssd::SrvResult *mSrvResult; const Dnssd::TxtResult *mTxtResult; const Dnssd::AddressResult *mAddressResult; + const Dnssd::RecordResult *mRecordResult; }; #endif @@ -407,6 +420,7 @@ private: uint16_t aPort); Error AppendTxtRecord(const ServiceInstanceInfo &aInstanceInfo); Error AppendTxtRecord(const void *aTxtData, uint16_t aTxtLength, uint32_t aTtl); + Error AppendGenericRecord(uint16_t aRrType, const void *aData, uint16_t aDataLength, uint32_t aTtl); Error AppendHostAddresses(AddrType aAddrType, const HostInfo &aHostInfo); Error AppendHostAddresses(const ServiceInstanceInfo &aInstanceInfo); Error AppendHostAddresses(AddrType aAddrType, const Ip6::Address *aAddrs, uint16_t aAddrsLength, uint32_t aTtl); @@ -423,6 +437,7 @@ private: Error AppendSrvRecord(const Srp::Server::Service &aService); Error AppendTxtRecord(const Srp::Server::Service &aService); Error AppendHostAddresses(const Srp::Server::Host &aHost); + Error AppendKeyRecord(const Srp::Server::Host &aHost); #endif #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE Error AppendPtrRecord(const ProxyResult &aResult); @@ -430,23 +445,23 @@ private: Error AppendTxtRecord(const ProxyResult &aResult); Error AppendHostIp6Addresses(const ProxyResult &aResult); Error AppendHostIp4Addresses(const ProxyResult &aResult); + Error AppendGenericRecord(const ProxyResult &aResult); #endif #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO) - void Log(void) const; - static const char *QueryTypeToString(QueryType aType); + void Log(void) const; #endif OwnedPtr mMessage; Header mHeader; - QueryType mType; + Questions mQuestions; Section mSection; NameOffsets mOffsets; }; struct ProxyQueryInfo : Message::FooterData { - QueryType mType; + Questions mQuestions; Ip6::MessageInfo mMessageInfo; TimeMilli mExpireTime; NameOffsets mOffsets; @@ -481,11 +496,12 @@ private: void Perform(ProxyAction aAction, ProxyQuery &aQuery, ProxyQueryInfo &aInfo); void ReadNameFor(ProxyAction aAction, ProxyQuery &aQuery, ProxyQueryInfo &aInfo, Name::Buffer &aName) const; - bool HasActive(ProxyAction aAction, const Name::Buffer &aName) const; + bool HasActive(ProxyAction aAction, const Name::Buffer &aName, uint16_t aQuerierRrType) const; bool QueryMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, ProxyAction aAction, - const Name::Buffer &aName) const; + const Name::Buffer &aName, + uint16_t aQuerierRrType) const; void UpdateProxy(Command aCommand, ProxyAction aAction, const ProxyQuery &aQuery, @@ -496,24 +512,27 @@ private: void StartOrStopTxtResolver(Command aCommand, const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo); void StartOrStopIp6Resolver(Command aCommand, Name::Buffer &aHostName); void StartOrStopIp4Resolver(Command aCommand, Name::Buffer &aHostName); + void StartOrStopRecordQuerier(Command aCommand, const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo); static void HandleBrowseResult(otInstance *aInstance, const otPlatDnssdBrowseResult *aResult); static void HandleSrvResult(otInstance *aInstance, const otPlatDnssdSrvResult *aResult); static void HandleTxtResult(otInstance *aInstance, const otPlatDnssdTxtResult *aResult); static void HandleIp6AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult); static void HandleIp4AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult); + static void HandleRecordResult(otInstance *aInstance, const otPlatDnssdRecordResult *aResult); void HandleBrowseResult(const Dnssd::BrowseResult &aResult); void HandleSrvResult(const Dnssd::SrvResult &aResult); void HandleTxtResult(const Dnssd::TxtResult &aResult); void HandleIp6AddressResult(const Dnssd::AddressResult &aResult); void HandleIp4AddressResult(const Dnssd::AddressResult &aResult); + void HandleRecordResult(const Dnssd::RecordResult &aResult); void HandleResult(ProxyAction aAction, const Name::Buffer &aName, ResponseAppender aAppender, const ProxyResult &aResult); - static bool IsActionForAdditionalSection(ProxyAction aAction, QueryType aQueryType); + static bool IsActionForAdditionalSection(ProxyAction aAction, const Questions &aQuestions); bool mIsRunning; }; @@ -539,6 +558,7 @@ private: static Error StripDomainName(const char *aFullName, Name::Buffer &aLabels); static Error StripDomainName(Name::Buffer &aName); static void ConstructFullName(const char *aLabels, Name::Buffer &aFullName); + static void ConstructFullName(const char *aFirstLabel, const char *aNextLabels, Name::Buffer &aFullName); static void ConstructFullInstanceName(const char *aInstanceLabel, const char *aServiceType, Name::Buffer &aFullName); diff --git a/src/ncp/platform/dnssd.cpp b/src/ncp/platform/dnssd.cpp index 0bdf37c7a..af97c3e81 100644 --- a/src/ncp/platform/dnssd.cpp +++ b/src/ncp/platform/dnssd.cpp @@ -165,4 +165,16 @@ void otPlatDnssdStopIp4AddressResolver(otInstance *aInstance, const otPlatDnssdA OT_UNUSED_VARIABLE(aResolver); } +void otPlatDnssdStartRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + +void otPlatDnssdStopRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + #endif // OPENTHREAD_FTD && OPENTHREAD_CONFIG_NCP_DNSSD_ENABLE && OPENTHREAD_CONFIG_PLATFORM_DNSSD_ENABLE diff --git a/tests/scripts/thread-cert/border_router/test_dnssd_server.py b/tests/scripts/thread-cert/border_router/test_dnssd_server.py index cb0d1f8ea..82bc34beb 100755 --- a/tests/scripts/thread-cert/border_router/test_dnssd_server.py +++ b/tests/scripts/thread-cert/border_router/test_dnssd_server.py @@ -196,11 +196,6 @@ class TestDnssdServerOnBr(thread_cert.TestCase): }) # check some invalid queries - for qtype in ['CNAME']: - dig_result = digger.dns_dig(server_addr, host1_full_name, qtype) - self._assert_dig_result_matches(dig_result, { - 'status': 'NOTIMP', - }) for service_name in WRONG_SERVICE_NAMES: dig_result = digger.dns_dig(server_addr, service_name, 'PTR') diff --git a/tests/scripts/thread-cert/border_router/test_dnssd_server_multi_border_routers.py b/tests/scripts/thread-cert/border_router/test_dnssd_server_multi_border_routers.py index 4741049ef..c85e680ce 100755 --- a/tests/scripts/thread-cert/border_router/test_dnssd_server_multi_border_routers.py +++ b/tests/scripts/thread-cert/border_router/test_dnssd_server_multi_border_routers.py @@ -278,11 +278,6 @@ class TestDnssdServerOnMultiBr(thread_cert.TestCase): self._verify_discovery_proxy_meshcop(br2_addr, br2.get_network_name(), host) # 4. Check some invalid queries - for qtype in ['CNAME']: - dig_result = host.dns_dig(br2_addr, host1_full_name, qtype) - self._assert_dig_result_matches(dig_result, { - 'status': 'NOTIMP', - }) for service_name in WRONG_SERVICE_NAMES: dig_result = host.dns_dig(br2_addr, service_name, 'PTR') diff --git a/tests/scripts/thread-cert/node.py b/tests/scripts/thread-cert/node.py index d21ac962f..cce49e15c 100755 --- a/tests/scripts/thread-cert/node.py +++ b/tests/scripts/thread-cert/node.py @@ -3581,6 +3581,52 @@ class NodeImpl: index = index + (5 if result[ins] else 1) return result + def dns_query(self, rrtype, first_label, next_labels, server=None, port=53): + """ + Send a DNS query for a given record type and name. + + Output is an array of records (as dictionary) with string keys and values. + [ + {'RecordType': '25', + 'RecordLength': '78', + 'TTL': '7105', + 'Section': 'answer', + 'Name': 'ins1._IPPS._TCP.DEFAULT.SERVICE.ARPA.', + 'RecordData': '[001900010000a0610...d45d3]' + } + ] + """ + cmd = f'dns query {rrtype} {first_label} {next_labels}' + if server is not None: + cmd += f' {server} {port}' + + self.send_command(cmd) + self.simulator.go(10) + output = self._expect_command_output() + + # Example output: + # DNS query response for ins1._IPPS._TCP.DEFAULT.SERVICE.ARPA. + # 0) + # RecordType:25, RecordLength:78, TTL:7105, Section:answer + # Name:ins1._IPPS._TCP.DEFAULT.SERVICE.ARPA. + # RecordData:[00190001000...cdb] + # Done + + result = [] + index = 1 # Skip first line + while (index < len(output)): + if (index > len(output) - 4): + break + record = {} + for line in output[index + 1:index + 4]: + for item in line.strip().split(','): + k, v = item.split(':') + record[k.strip()] = v.strip() + result.append(record) + index += 4 + + return result + def set_mliid(self, mliid: str): cmd = f'mliid {mliid}' self.send_command(cmd) diff --git a/tests/scripts/thread-cert/test_dnssd.py b/tests/scripts/thread-cert/test_dnssd.py index 38c4f85ef..5642a0e1e 100755 --- a/tests/scripts/thread-cert/test_dnssd.py +++ b/tests/scripts/thread-cert/test_dnssd.py @@ -209,6 +209,54 @@ class TestDnssd(thread_cert.TestCase): service_instance = client1.dns_resolve_service('ins4', f'{SERVICE}.{DOMAIN}'.upper(), server.get_mleid(), 53) self._assert_service_instance_equal(service_instance, instance4_verify_info) + #--------------------------------------------------------------- + # Query for KEY record for `ins1` service name + + records = client1.dns_query(25, 'ins1', f'{SERVICE}.{DOMAIN}'.upper(), server.get_mleid(), 53) + self.assertEqual(len(records), 1) + record = records[0] + self.assertEqual(int(record['RecordType']), 25) + self.assertEqual(int(record['RecordLength']), 78) + self.assertTrue(int(record['TTL']) > 0) + self.assertEqual(record['Section'], 'answer') + self.assertEqual(record['Name'].lower(), 'ins1._ipps._tcp.default.service.arpa.') + self.assertIn('RecordData', record) + + #--------------------------------------------------------------- + # Query for SRV record for `ins1` service name + + records = client1.dns_query(33, 'ins1', f'{SERVICE}.{DOMAIN}'.upper(), server.get_mleid(), 53) + self.assertEqual(len(records), 4) + + # SRV record in answer section + record = records[0] + self.assertEqual(int(record['RecordType']), 33) + self.assertTrue(int(record['RecordLength']) > 0) + self.assertTrue(int(record['TTL']) > 0) + self.assertEqual(record['Section'], 'answer') + self.assertEqual(record['Name'].lower(), 'ins1._ipps._tcp.default.service.arpa.') + self.assertIn('RecordData', record) + + # Other records TXT and A in additional section + for record in records[1:]: + self.assertTrue(int(record['RecordLength']) > 0) + self.assertIn('RecordData', record) + self.assertTrue(int(record['TTL']) > 0) + self.assertEqual(record['Section'], 'additional') + rrtype = int(record['RecordType']) + self.assertIn(rrtype, [16, 28]) # TXT and AAAA + if rrtype == 16: + self.assertEqual(record['Name'].lower(), 'ins1._ipps._tcp.default.service.arpa.') + self.assertEqual(record['RecordData'], '[00]') + else: + self.assertEqual(record['Name'].lower(), 'host1.default.service.arpa.') + + #--------------------------------------------------------------- + # Query for non-existing A record for `ins1` service name + + records = client1.dns_query(1, 'ins1', f'{SERVICE}.{DOMAIN}'.upper(), server.get_mleid(), 53) + self.assertEqual(len(records), 0) + def _assert_service_instance_equal(self, instance, info): self.assertEqual(instance['host'].lower(), info['host'].lower(), instance) for f in ('port', 'priority', 'weight', 'txt_data'): diff --git a/tests/unit/test_dns_client.cpp b/tests/unit/test_dns_client.cpp index b8405d856..62118adc4 100644 --- a/tests/unit/test_dns_client.cpp +++ b/tests/unit/test_dns_client.cpp @@ -542,6 +542,133 @@ exit: return; } +//- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static constexpr uint16_t kMaxRecords = 16; + +struct QueryRecordInfo +{ + struct Record : public Dns::Client::RecordInfo + { + static constexpr uint16_t kMaxRecordDataSize = 200; + + void Init(void) + { + ClearAllBytes(*this); + mNameBuffer = mName; + mNameBufferSize = sizeof(mName); + mDataBuffer = mData; + mDataBufferSize = sizeof(mData); + } + + uint8_t mData[kMaxRecordDataSize]; + char mName[Dns::Name::kMaxNameSize]; + }; + + void Reset(void) { memset(this, 0, sizeof(*this)); }; + + uint16_t mCallbackCount; + Error mError; + char mQueryName[Dns::Name::kMaxNameSize]; + Record mRecords[kMaxRecords]; + uint16_t mNumRecords; +}; + +static QueryRecordInfo sQueryRecordInfo; + +void RecordCallback(otError aError, const otDnsRecordResponse *aResponse, void *aContext) +{ + static constexpr uint16_t kMaxStringSize = 400; + + const Dns::Client::RecordResponse &response = AsCoreType(aResponse); + + Log("RecordCallback"); + Log(" Error: %s", ErrorToString(aError)); + + VerifyOrQuit(aContext == sInstance); + + sQueryRecordInfo.mCallbackCount++; + sQueryRecordInfo.mError = aError; + sQueryRecordInfo.mNumRecords = 0; + + SuccessOrExit(aError); + + SuccessOrQuit(response.GetQueryName(sQueryRecordInfo.mQueryName, sizeof(sQueryRecordInfo.mQueryName))); + Log(" QueryName: %s", sQueryRecordInfo.mQueryName); + + for (uint8_t index = 0; index < kMaxRecords; index++) + { + Error error; + uint32_t ttl; + + sQueryRecordInfo.mRecords[index].Init(); + + error = response.GetRecordInfo(index, sQueryRecordInfo.mRecords[index]); + + if (error == kErrorNotFound) + { + sQueryRecordInfo.mNumRecords = index; + break; + } + + SuccessOrQuit(error); + } + + Log(" NumRecords: %u", sQueryRecordInfo.mNumRecords); + + for (uint16_t index = 0; index < sQueryRecordInfo.mNumRecords; index++) + { + const QueryRecordInfo::Record &record = sQueryRecordInfo.mRecords[index]; + String string; + uint16_t rrType; + + string.AppendHexBytes(record.mDataBuffer, record.mDataBufferSize); + rrType = record.mRecordType; + + Log(" Record %u", index); + Log(" Name: %s", record.mNameBuffer); + Log(" Type: %u (%s)", rrType, Dns::ResourceRecord::TypeToString(rrType).AsCString()); + Log(" Data: %s", string.AsCString()); + } + +exit: + return; +} + +void ValidateSrvRecordData(const QueryRecordInfo::Record &aRecord, const char *aFullHostName) +{ + // Validate that the read SRV record data contains + // the uncompressed host name. + + Message *data = sInstance->Get().Allocate(Message::kTypeOther); + uint16_t offset = sizeof(Dns::SrvRecord) - sizeof(Dns::ResourceRecord); + + VerifyOrQuit(data != nullptr); + SuccessOrQuit(data->AppendBytes(aRecord.mDataBuffer, aRecord.mRecordLength)); + + SuccessOrQuit(Dns::Name::CompareName(*data, offset, aFullHostName)); + VerifyOrQuit(offset == data->GetLength()); + + data->Free(); +} + +void ValidatePtrRecordData(const QueryRecordInfo::Record &aRecord, const char *aFullInstanceName) +{ + // Validate that the read PTR record data contains + // the uncompressed service instance name. + + Message *data = sInstance->Get().Allocate(Message::kTypeOther); + uint16_t offset = 0; + + VerifyOrQuit(data != nullptr); + SuccessOrQuit(data->AppendBytes(aRecord.mDataBuffer, aRecord.mRecordLength)); + + SuccessOrQuit(Dns::Name::CompareName(*data, offset, aFullInstanceName)); + VerifyOrQuit(offset == data->GetLength()); + + data->Free(); +} + //---------------------------------------------------------------------------------------------------------------------- void TestDnsClient(void) @@ -695,6 +822,160 @@ void TestDnsClient(void) VerifyOrQuit(sAddressInfo.mCallbackCount == 1); VerifyOrQuit(sAddressInfo.mError != kErrorNone); + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Validate DNS Client `QueryRecord()` for host name + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for KEY RR", kHostFullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, kHostName, "default.service.arpa.", + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 1); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, kHostFullName)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeKey); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength == sizeof(Dns::Ecdsa256KeyRecord)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sizeof(Dns::Ecdsa256KeyRecord)); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for misc RR", kHostFullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeCname, kHostName, "default.service.arpa.", + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 0); + + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Validate DNS Client `QueryRecord()` for service instance name and KEY record + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for KEY RR", kInstance1FullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, kInstance1Label, kService1FullName, + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 1); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, kInstance1FullName)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeKey); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength == sizeof(Dns::Ecdsa256KeyRecord)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sizeof(Dns::Ecdsa256KeyRecord)); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for misc RR", kInstance1FullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeCname, kInstance1Label, kService1FullName, + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 0); + + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Validate DNS Client `QueryRecord()` for service instance name and SRV record + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for SRV record", kInstance1FullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeSrv, kInstance1Label, kService1FullName, + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 4); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, kInstance1FullName)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeSrv); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sQueryRecordInfo.mRecords[0].mRecordLength); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + + ValidateSrvRecordData(sQueryRecordInfo.mRecords[0], kHostFullName); + + // Validate the records in additional data (TXT and two AAAA). + + for (uint8_t index = 1; index < 4; index++) + { + const QueryRecordInfo::Record &record = sQueryRecordInfo.mRecords[index]; + + VerifyOrQuit(record.mRecordLength > 0); + VerifyOrQuit(record.mTtl > 0); + VerifyOrQuit(record.mDataBufferSize == record.mRecordLength); + VerifyOrQuit(MapEnum(record.mSection) == Dns::Client::RecordInfo::kSectionAdditional); + + switch (record.mRecordType) + { + case Dns::ResourceRecord::kTypeTxt: + VerifyOrQuit(!strcmp(record.mNameBuffer, kInstance1FullName)); + break; + case Dns::ResourceRecord::kTypeAaaa: + VerifyOrQuit(!strcmp(record.mNameBuffer, kHostFullName)); + VerifyOrQuit(record.mRecordLength == sizeof(Ip6::Address)); + break; + default: + VerifyOrQuit(false); + break; + } + } + + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // Validate DNS Client `QueryRecord()` for PTR record + + sQueryRecordInfo.Reset(); + Log("QueryRecord(%s) for PTR record", kService1FullName); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypePtr, "_srv", "_udp.default.service.arpa.", + RecordCallback, sInstance)); + AdvanceTime(100); + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 5); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, kService1FullName)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypePtr); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl > 0); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sQueryRecordInfo.mRecords[0].mRecordLength); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + + ValidatePtrRecordData(sQueryRecordInfo.mRecords[0], kInstance1FullName); + + // Validate the records in additional data (SRV, TXT and two AAAA). + + for (uint8_t index = 1; index < 5; index++) + { + const QueryRecordInfo::Record &record = sQueryRecordInfo.mRecords[index]; + + VerifyOrQuit(record.mRecordLength > 0); + VerifyOrQuit(record.mTtl > 0); + VerifyOrQuit(record.mDataBufferSize == record.mRecordLength); + VerifyOrQuit(MapEnum(record.mSection) == Dns::Client::RecordInfo::kSectionAdditional); + + switch (record.mRecordType) + { + case Dns::ResourceRecord::kTypeSrv: + VerifyOrQuit(!strcmp(record.mNameBuffer, kInstance1FullName)); + ValidateSrvRecordData(record, kHostFullName); + break; + case Dns::ResourceRecord::kTypeTxt: + VerifyOrQuit(!strcmp(record.mNameBuffer, kInstance1FullName)); + break; + case Dns::ResourceRecord::kTypeAaaa: + VerifyOrQuit(!strcmp(record.mNameBuffer, kHostFullName)); + VerifyOrQuit(record.mRecordLength == sizeof(Ip6::Address)); + break; + default: + VerifyOrQuit(false); + break; + } + } + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // Validate DNS Client `Browse()` diff --git a/tests/unit/test_dnssd_discovery_proxy.cpp b/tests/unit/test_dnssd_discovery_proxy.cpp index 1c6db5f1b..7470b3127 100644 --- a/tests/unit/test_dnssd_discovery_proxy.cpp +++ b/tests/unit/test_dnssd_discovery_proxy.cpp @@ -487,6 +487,99 @@ exit: return; } +//- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static constexpr uint16_t kMaxRecordDataSize = 128; +static constexpr uint16_t kMaxRecords = 4; + +struct QueryRecordInfo +{ + struct Record : public Dns::Client::RecordInfo + { + void Init(void) + { + ClearAllBytes(*this); + mNameBuffer = mName; + mNameBufferSize = sizeof(mName); + mDataBuffer = mData; + mDataBufferSize = sizeof(mData); + } + + uint8_t mData[kMaxRecordDataSize]; + char mName[Dns::Name::kMaxNameSize]; + }; + + void Reset(void) { memset(this, 0, sizeof(*this)); }; + + uint16_t mCallbackCount; + Error mError; + char mQueryName[Dns::Name::kMaxNameSize]; + Record mRecords[kMaxRecords]; + uint16_t mNumRecords; +}; + +static QueryRecordInfo sQueryRecordInfo; + +void RecordCallback(otError aError, const otDnsRecordResponse *aResponse, void *aContext) +{ + static constexpr uint16_t kMaxStringSize = 400; + + const Dns::Client::RecordResponse &response = AsCoreType(aResponse); + + Log("RecordCallback"); + Log(" Error: %s", ErrorToString(aError)); + + VerifyOrQuit(aContext == sInstance); + + sQueryRecordInfo.mCallbackCount++; + sQueryRecordInfo.mError = aError; + sQueryRecordInfo.mNumRecords = 0; + + SuccessOrExit(aError); + + SuccessOrQuit(response.GetQueryName(sQueryRecordInfo.mQueryName, sizeof(sQueryRecordInfo.mQueryName))); + Log(" QueryName: %s", sQueryRecordInfo.mQueryName); + + for (uint8_t index = 0; index < kMaxRecords; index++) + { + Error error; + uint32_t ttl; + + sQueryRecordInfo.mRecords[index].Init(); + + error = response.GetRecordInfo(index, sQueryRecordInfo.mRecords[index]); + + if (error == kErrorNotFound) + { + sQueryRecordInfo.mNumRecords = index; + break; + } + + SuccessOrQuit(error); + } + + Log(" NumRecords: %u", sQueryRecordInfo.mNumRecords); + + for (uint16_t index = 0; index < sQueryRecordInfo.mNumRecords; index++) + { + const QueryRecordInfo::Record &record = sQueryRecordInfo.mRecords[index]; + String string; + uint16_t rrType; + + string.AppendHexBytes(record.mDataBuffer, record.mDataBufferSize); + rrType = record.mRecordType; + + Log(" Record %u", index); + Log(" Name: %s", record.mNameBuffer); + Log(" Type: %u (%s)", rrType, Dns::ResourceRecord::TypeToString(rrType).AsCString()); + Log(" TTL: %lu", ToUlong(record.mTtl)); + Log(" Data: %s", string.AsCString()); + } + +exit: + return; +} + //---------------------------------------------------------------------------------------------------------------------- // otPlatDnssd APIs @@ -582,6 +675,28 @@ struct IpAddrResolverInfo : public Clearable otPlatDnssdAddressCallback mCallback; }; +struct RecordQuerierInfo : public Clearable +{ + bool NameMatches(const char *aFirsLabel, const char *aNextLabels) const + { + return !strcmp(mFirstLabel, aFirsLabel) && !strcmp(mNextLabels, aNextLabels == nullptr ? "" : aNextLabels); + } + + void UpdateFrom(const otPlatDnssdRecordQuerier *aQuerier) + { + mCallCount++; + CopyString(mFirstLabel, aQuerier->mFirstLabel); + CopyString(mNextLabels, aQuerier->mNextLabels); + mCallback = aQuerier->mCallback; + } + + uint16_t mCallCount; + char mFirstLabel[Dns::Name::kMaxLabelSize]; + char mNextLabels[Dns::Name::kMaxNameSize]; + uint16_t mRecordType; + otPlatDnssdRecordCallback mCallback; +}; + struct InvokeOnStart : public Clearable { // When not null, these entries are used to invoke callback @@ -593,6 +708,7 @@ struct InvokeOnStart : public Clearable const otPlatDnssdTxtResult *mTxtResult; const otPlatDnssdAddressResult *mIp6AddrResult; const otPlatDnssdAddressResult *mIp4AddrResult; + const otPlatDnssdRecordResult *mRecordResult; }; static BrowserInfo sStartBrowserInfo; @@ -605,6 +721,8 @@ static IpAddrResolverInfo sStartIp6AddrResolverInfo; static IpAddrResolverInfo sStopIp6AddrResolverInfo; static IpAddrResolverInfo sStartIp4AddrResolverInfo; static IpAddrResolverInfo sStopIp4AddrResolverInfo; +static RecordQuerierInfo sStartRecordQuerierInfo; +static RecordQuerierInfo sStopRecordQuerierInfo; static InvokeOnStart sInvokeOnStart; @@ -620,6 +738,8 @@ void ResetPlatDnssdApiInfo(void) sStopIp6AddrResolverInfo.Clear(); sStartIp4AddrResolverInfo.Clear(); sStopIp4AddrResolverInfo.Clear(); + sStartRecordQuerierInfo.Clear(); + sStopRecordQuerierInfo.Clear(); sInvokeOnStart.Clear(); } @@ -694,6 +814,26 @@ void InvokeIp4AddrResolverCallback(const otPlatDnssdAddressCallback aCallback, c aCallback(sInstance, &aResult); } +void InvokeRecordQuerierCallback(const otPlatDnssdRecordCallback aCallback, const otPlatDnssdRecordResult &aResult) +{ + static constexpr uint16_t kMaxDataStringSize = 400; + + String string; + + string.AppendHexBytes(aResult.mRecordData, aResult.mRecordDataLength); + + Log("Invoking record callback"); + Log(" firstLabel : %s", aResult.mFirstLabel); + Log(" nextLabels : %s", StringNullCheck(aResult.mNextLabels)); + Log(" recordType : %s", Dns::ResourceRecord::TypeToString(aResult.mRecordType).AsCString()); + Log(" ttl : %u", aResult.mTtl); + Log(" if-index : %u", aResult.mInfraIfIndex); + Log(" dataLength : %u", aResult.mRecordDataLength); + Log(" data : %s", string.AsCString()); + + aCallback(sInstance, &aResult); +} + otPlatDnssdState otPlatDnssdGetState(otInstance *aInstance) { OT_UNUSED_VARIABLE(aInstance); @@ -856,6 +996,42 @@ void otPlatDnssdStopIp4AddressResolver(otInstance *aInstance, const otPlatDnssdA } } +void otPlatDnssdStartRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + VerifyOrQuit(aQuerier != nullptr); + + Log("otPlatDnssdStartRecordQuerier(%s, %s, %s)", aQuerier->mFirstLabel, StringNullCheck(aQuerier->mNextLabels), + Dns::ResourceRecord::TypeToString(aQuerier->mRecordType).AsCString()); + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aQuerier->mInfraIfIndex == kInfraIfIndex); + + sStartRecordQuerierInfo.UpdateFrom(aQuerier); + + if (sInvokeOnStart.mRecordResult != nullptr) + { + InvokeRecordQuerierCallback(aQuerier->mCallback, *sInvokeOnStart.mRecordResult); + } +} + +void otPlatDnssdStopRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + VerifyOrQuit(aQuerier != nullptr); + + Log("otPlatDnssdStopRecordQuerier(\"%s, %s, %s\")", aQuerier->mFirstLabel, StringNullCheck(aQuerier->mNextLabels), + Dns::ResourceRecord::TypeToString(aQuerier->mRecordType).AsCString()); + + VerifyOrQuit(aInstance == sInstance); + VerifyOrQuit(aQuerier->mInfraIfIndex == kInfraIfIndex); + + sStopRecordQuerierInfo.UpdateFrom(aQuerier); + + if (sInvokeOnStart.mRecordResult != nullptr) + { + InvokeRecordQuerierCallback(aQuerier->mCallback, *sInvokeOnStart.mRecordResult); + } +} + //---------------------------------------------------------------------------------------------------------------------- void TestProxyBasic(void) @@ -863,6 +1039,7 @@ void TestProxyBasic(void) static constexpr uint32_t kTtl = 300; const uint8_t kTxtData[] = {3, 'A', '=', '1', 0}; + const uint8_t kKeyData[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99}; Srp::Server *srpServer; Srp::Client *srpClient; @@ -874,6 +1051,7 @@ void TestProxyBasic(void) Dnssd::AddressResult ip6AddrrResult; Dnssd::AddressResult ip4AddrrResult; Dnssd::AddressAndTtl addressAndTtl; + Dnssd::RecordResult recordResult; NetworkData::ExternalRouteConfig routeConfig; Ip6::Address address; @@ -942,6 +1120,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_avenger._udp")); @@ -970,6 +1150,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopBrowserInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopBrowserInfo.mCallback == sStartBrowserInfo.mCallback); @@ -1001,6 +1183,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopSrvResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopSrvResolverInfo.ServiceInstanceMatches("hulk")); @@ -1035,6 +1219,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopTxtResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopTxtResolverInfo.ServiceInstanceMatches("hulk")); @@ -1069,6 +1255,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("compound")); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallback == sStartIp6AddrResolverInfo.mCallback); @@ -1114,6 +1302,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartSrvResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStartSrvResolverInfo.ServiceInstanceMatches("iron.man")); @@ -1144,6 +1334,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sResolveServiceInfo.mCallbackCount == 0); @@ -1170,6 +1362,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sResolveServiceInfo.mCallbackCount == 0); @@ -1204,6 +1398,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sResolveServiceInfo.mCallbackCount == 0); @@ -1233,6 +1429,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("starktower")); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallback == sStartIp6AddrResolverInfo.mCallback); @@ -1273,6 +1471,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.HostNameMatches("earth")); @@ -1304,6 +1504,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("earth")); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallback == sStartIp6AddrResolverInfo.mCallback); @@ -1339,6 +1541,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.HostNameMatches("shield")); @@ -1370,6 +1574,8 @@ void TestProxyBasic(void) VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 1); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopIp4AddrResolverInfo.HostNameMatches("shield")); VerifyOrQuit(sStopIp4AddrResolverInfo.mCallback == sStartIp4AddrResolverInfo.mCallback); @@ -1396,6 +1602,162 @@ void TestProxyBasic(void) VerifyOrQuit(sResolveAddressInfo.mHostAddresses[0] == address); + Log("--------------------------------------------------------------------------------------------"); + + ResetPlatDnssdApiInfo(); + sQueryRecordInfo.Reset(); + + Log("QueryRecord()"); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, "shield", "default.service.arpa.", + RecordCallback, sInstance)); + AdvanceTime(10); + + // Check that a record querier is started + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); + + VerifyOrQuit(sStartRecordQuerierInfo.NameMatches("shield", nullptr)); + + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 0); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); + Log("Invoke Record Querier callback"); + + recordResult.mFirstLabel = "shield"; + recordResult.mNextLabels = nullptr; + recordResult.mRecordType = Dns::ResourceRecord::kTypeKey; + recordResult.mRecordData = kKeyData; + recordResult.mRecordDataLength = sizeof(kKeyData); + recordResult.mTtl = kTtl; + recordResult.mInfraIfIndex = kInfraIfIndex; + + InvokeRecordQuerierCallback(sStartRecordQuerierInfo.mCallback, recordResult); + + AdvanceTime(10); + + // Check that the record querier is stopped + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 1); + + VerifyOrQuit(sStopRecordQuerierInfo.NameMatches("shield", nullptr)); + VerifyOrQuit(sStopRecordQuerierInfo.mCallback == sStartRecordQuerierInfo.mCallback); + + // Check that response is sent to client and validate it + + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mQueryName, "shield.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 1); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, "shield.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeKey); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength == sizeof(kKeyData)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl == kTtl); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sizeof(kKeyData)); + VerifyOrQuit(!memcmp(sQueryRecordInfo.mRecords[0].mDataBuffer, kKeyData, sizeof(kKeyData))); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + + Log("--------------------------------------------------------------------------------------------"); + + ResetPlatDnssdApiInfo(); + sQueryRecordInfo.Reset(); + + Log("QueryRecord()"); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, "iron.man", + "_avenger._udp.default.service.arpa.", RecordCallback, sInstance)); + AdvanceTime(10); + + // Check that a record querier is started + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); + + VerifyOrQuit(sStartRecordQuerierInfo.NameMatches("iron.man", "_avenger._udp")); + + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 0); + + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); + Log("Invoke Record Querier callback"); + + recordResult.mFirstLabel = "iron.man"; + recordResult.mNextLabels = "_avenger._udp"; + recordResult.mRecordType = Dns::ResourceRecord::kTypeKey; + recordResult.mRecordData = kKeyData; + recordResult.mRecordDataLength = sizeof(kKeyData); + recordResult.mTtl = kTtl; + recordResult.mInfraIfIndex = kInfraIfIndex; + + InvokeRecordQuerierCallback(sStartRecordQuerierInfo.mCallback, recordResult); + + AdvanceTime(10); + + // Check that the record querier is stopped + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopIp4AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 1); + + VerifyOrQuit(sStopRecordQuerierInfo.NameMatches("iron.man", "_avenger._udp")); + VerifyOrQuit(sStopRecordQuerierInfo.mCallback == sStartRecordQuerierInfo.mCallback); + + // Check that response is sent to client and validate it + + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mQueryName, "iron.man._avenger._udp.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 1); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, "iron.man._avenger._udp.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeKey); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength == sizeof(kKeyData)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl == kTtl); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sizeof(kKeyData)); + VerifyOrQuit(!memcmp(sQueryRecordInfo.mRecords[0].mDataBuffer, kKeyData, sizeof(kKeyData))); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); Log("Stop DNS-SD server"); @@ -1480,6 +1842,8 @@ void TestProxySubtypeBrowse(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStartBrowserInfo.SubTypeMatches("_god")); @@ -1509,6 +1873,8 @@ void TestProxySubtypeBrowse(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopBrowserInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopBrowserInfo.SubTypeMatches("_god")); @@ -1544,6 +1910,8 @@ void TestProxySubtypeBrowse(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopSrvResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopSrvResolverInfo.ServiceInstanceMatches("thor")); @@ -1575,6 +1943,8 @@ void TestProxySubtypeBrowse(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopTxtResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopTxtResolverInfo.ServiceInstanceMatches("thor")); @@ -1607,6 +1977,8 @@ void TestProxySubtypeBrowse(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("asgard")); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallback == sStartIp6AddrResolverInfo.mCallback); @@ -1716,6 +2088,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_game._ps5")); @@ -1754,6 +2128,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_avenger._udp")); @@ -1780,6 +2156,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopBrowserInfo.ServiceTypeMatches("_avenger._udp")); @@ -1806,6 +2184,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStopBrowserInfo.ServiceTypeMatches("_avenger._udp")); @@ -1834,6 +2214,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_avenger._udp")); @@ -1850,6 +2232,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_game._udp")); @@ -1867,6 +2251,8 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 0); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartSrvResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStartSrvResolverInfo.ServiceInstanceMatches("wanda")); @@ -1884,9 +2270,30 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.HostNameMatches("earth")); + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); + Log("QueryRecord()"); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, "iron.man", + "_avenger._udp.default.service.arpa.", RecordCallback, sInstance)); + AdvanceTime(10); + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 2); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 0); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 1); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 0); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); + + VerifyOrQuit(sStartRecordQuerierInfo.NameMatches("iron.man", "_avenger._udp")); + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); Log("Wait for timeout for all requests"); @@ -1911,10 +2318,13 @@ void TestProxyTimeout(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 1); VerifyOrQuit(sStopSrvResolverInfo.ServiceTypeMatches("_avenger._udp")); VerifyOrQuit(sStopSrvResolverInfo.ServiceInstanceMatches("wanda")); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("earth")); + VerifyOrQuit(sStopRecordQuerierInfo.NameMatches("iron.man", "_avenger._udp")); Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); Log("Stop DNS-SD server"); @@ -2634,6 +3044,7 @@ void TestProxyInvokeCallbackFromStartApi(void) static constexpr uint32_t kTtl = 300; const uint8_t kTxtData[] = {3, 'A', '=', '1', 0}; + const uint8_t kKeyData[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99}; Srp::Server *srpServer; Srp::Client *srpClient; @@ -2644,6 +3055,7 @@ void TestProxyInvokeCallbackFromStartApi(void) Dnssd::TxtResult txtResult; Dnssd::AddressResult ip6AddrrResult; Dnssd::AddressAndTtl addressAndTtl[2]; + Dnssd::RecordResult recordResult; Log("--------------------------------------------------------------------------------------------"); Log("TestProxyInvokeCallbackFromStartApi"); @@ -2687,6 +3099,7 @@ void TestProxyInvokeCallbackFromStartApi(void) sInvokeOnStart.mSrvResult = &srvResult; sInvokeOnStart.mTxtResult = &txtResult; sInvokeOnStart.mIp6AddrResult = &ip6AddrrResult; + sInvokeOnStart.mRecordResult = &recordResult; browseResult.mServiceType = "_guardian._glaxy"; browseResult.mSubTypeLabel = nullptr; @@ -2718,6 +3131,14 @@ void TestProxyInvokeCallbackFromStartApi(void) ip6AddrrResult.mAddresses = addressAndTtl; ip6AddrrResult.mAddressesLength = 2; + recordResult.mFirstLabel = "drax"; + recordResult.mNextLabels = "_guardian._glaxy"; + recordResult.mRecordType = Dns::ResourceRecord::kTypeKey; + recordResult.mRecordData = kKeyData; + recordResult.mRecordDataLength = sizeof(kKeyData); + recordResult.mTtl = kTtl; + recordResult.mInfraIfIndex = kInfraIfIndex; + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); sBrowseInfo.Reset(); @@ -2736,6 +3157,8 @@ void TestProxyInvokeCallbackFromStartApi(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 1); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 1); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 1); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartBrowserInfo.ServiceTypeMatches("_guardian._glaxy")); VerifyOrQuit(sStopBrowserInfo.ServiceTypeMatches("_guardian._glaxy")); @@ -2786,6 +3209,8 @@ void TestProxyInvokeCallbackFromStartApi(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 2); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 2); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 2); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartSrvResolverInfo.ServiceTypeMatches("_guardian._glaxy")); VerifyOrQuit(sStartSrvResolverInfo.ServiceInstanceMatches("mantis")); @@ -2836,6 +3261,8 @@ void TestProxyInvokeCallbackFromStartApi(void) VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 2); VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 3); VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 3); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 0); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 0); VerifyOrQuit(sStartIp6AddrResolverInfo.HostNameMatches("nova")); VerifyOrQuit(sStopIp6AddrResolverInfo.HostNameMatches("nova")); @@ -2854,6 +3281,46 @@ void TestProxyInvokeCallbackFromStartApi(void) sResolveAddressInfo.mHostAddresses[index] == AsCoreType(&addressAndTtl[1].mAddress)); } + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); + + sQueryRecordInfo.Reset(); + Log("QueryRecord()"); + SuccessOrQuit(dnsClient->QueryRecord(Dns::ResourceRecord::kTypeKey, "drax", + "_guardian._glaxy.default.service.arpa.", RecordCallback, sInstance)); + AdvanceTime(10); + + // Check that the record querier is started and then stopped + + VerifyOrQuit(sStartBrowserInfo.mCallCount == 1); + VerifyOrQuit(sStopBrowserInfo.mCallCount == 1); + VerifyOrQuit(sStartSrvResolverInfo.mCallCount == 2); + VerifyOrQuit(sStopSrvResolverInfo.mCallCount == 2); + VerifyOrQuit(sStartTxtResolverInfo.mCallCount == 2); + VerifyOrQuit(sStopTxtResolverInfo.mCallCount == 2); + VerifyOrQuit(sStartIp6AddrResolverInfo.mCallCount == 3); + VerifyOrQuit(sStopIp6AddrResolverInfo.mCallCount == 3); + VerifyOrQuit(sStartRecordQuerierInfo.mCallCount == 1); + VerifyOrQuit(sStopRecordQuerierInfo.mCallCount == 1); + + VerifyOrQuit(sStartRecordQuerierInfo.NameMatches("drax", "_guardian._glaxy")); + VerifyOrQuit(sStopRecordQuerierInfo.NameMatches("drax", "_guardian._glaxy")); + + // Validate the query record response on client + + VerifyOrQuit(sQueryRecordInfo.mCallbackCount == 1); + SuccessOrQuit(sQueryRecordInfo.mError); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mQueryName, "drax._guardian._glaxy.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mNumRecords == 1); + + VerifyOrQuit(!strcmp(sQueryRecordInfo.mRecords[0].mNameBuffer, "drax._guardian._glaxy.default.service.arpa.")); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordType == Dns::ResourceRecord::kTypeKey); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mRecordLength == sizeof(kKeyData)); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mTtl == kTtl); + VerifyOrQuit(sQueryRecordInfo.mRecords[0].mDataBufferSize == sizeof(kKeyData)); + VerifyOrQuit(!memcmp(sQueryRecordInfo.mRecords[0].mDataBuffer, kKeyData, sizeof(kKeyData))); + VerifyOrQuit(MapEnum(sQueryRecordInfo.mRecords[0].mSection) == Dns::Client::RecordInfo::kSectionAnswer); + Log("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - "); Log("Stop DNS-SD server"); diff --git a/tests/unit/test_platform.cpp b/tests/unit/test_platform.cpp index f5a246f22..b075acacf 100644 --- a/tests/unit/test_platform.cpp +++ b/tests/unit/test_platform.cpp @@ -963,6 +963,18 @@ OT_TOOL_WEAK void otPlatDnssdStopIp4AddressResolver(otInstance *aInstance, const OT_UNUSED_VARIABLE(aResolver); } +OT_TOOL_WEAK void otPlatDnssdStartRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + +OT_TOOL_WEAK void otPlatDnssdStopRecordQuerier(otInstance *aInstance, const otPlatDnssdRecordQuerier *aQuerier) +{ + OT_UNUSED_VARIABLE(aInstance); + OT_UNUSED_VARIABLE(aQuerier); +} + #endif // OPENTHREAD_CONFIG_PLATFORM_DNSSD_ENABLE #if OPENTHREAD_CONFIG_PLATFORM_LOG_CRASH_DUMP_ENABLE