diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg b/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg index 4aa890ca5..6cfbe7cb6 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg +++ b/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg @@ -39,7 +39,7 @@ package_dir= =src packages=find_namespace: install_requires = - opentelemetry-api == 0.15.b0 + opentelemetry-api == 0.16.dev0 [options.entry_points] opentelemetry_propagator = @@ -47,7 +47,7 @@ opentelemetry_propagator = [options.extras_require] test = - opentelemetry-test == 0.15.b0 + opentelemetry-test == 0.16.dev0 [options.packages.find] where = src diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py index 17ef67ed9..e544bd685 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py @@ -55,10 +55,6 @@ class AwsXRayFormat(TextMapPropagator): IS_SAMPLED = "1" NOT_SAMPLED = "0" - # pylint: disable=too-many-locals - # pylint: disable=too-many-return-statements - # pylint: disable=too-many-branches - # pylint: disable=too-many-statements def extract( self, getter: Getter[TextMapPropagatorT], @@ -80,136 +76,14 @@ class AwsXRayFormat(TextMapPropagator): trace.INVALID_SPAN, context=context ) - trace_id = trace.INVALID_TRACE_ID - span_id = trace.INVALID_SPAN_ID - sampled = False + trace_id, span_id, sampled, err = self.extract_span_properties( + trace_header + ) - next_kv_pair_start = 0 - - while next_kv_pair_start < len(trace_header): - try: - kv_pair_delimiter_index = trace_header.index( - self.KV_PAIR_DELIMITER, next_kv_pair_start - ) - kv_pair_subset = trace_header[ - next_kv_pair_start:kv_pair_delimiter_index - ] - next_kv_pair_start = kv_pair_delimiter_index + 1 - except ValueError: - kv_pair_subset = trace_header[next_kv_pair_start:] - next_kv_pair_start = len(trace_header) - - stripped_kv_pair = kv_pair_subset.strip() - - try: - key_and_value_delimiter_index = stripped_kv_pair.index( - self.KEY_AND_VALUE_DELIMITER - ) - except ValueError: - _logger.error( - ( - "Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.", - kv_pair_subset, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - - value = stripped_kv_pair[key_and_value_delimiter_index + 1 :] - - if stripped_kv_pair.startswith(self.TRACE_ID_KEY): - if ( - len(value) != self.TRACE_ID_LENGTH - or not value.startswith(self.TRACE_ID_VERSION) - or value[self.TRACE_ID_DELIMITER_INDEX_1] - != self.TRACE_ID_DELIMITER - or value[self.TRACE_ID_DELIMITER_INDEX_2] - != self.TRACE_ID_DELIMITER - ): - _logger.error( - ( - "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", - self.TRACE_HEADER_KEY, - trace_header, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - - timestamp_subset = value[ - self.TRACE_ID_DELIMITER_INDEX_1 - + 1 : self.TRACE_ID_DELIMITER_INDEX_2 - ] - unique_id_subset = value[ - self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH - ] - try: - trace_id = int(timestamp_subset + unique_id_subset, 16) - except ValueError: - _logger.error( - ( - "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", - self.TRACE_HEADER_KEY, - trace_header, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - elif stripped_kv_pair.startswith(self.PARENT_ID_KEY): - if len(value) != self.PARENT_ID_LENGTH: - _logger.error( - ( - "Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", - self.TRACE_HEADER_KEY, - trace_header, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - - try: - span_id = int(value, 16) - except ValueError: - _logger.error( - ( - "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", - self.TRACE_HEADER_KEY, - trace_header, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - elif stripped_kv_pair.startswith(self.SAMPLED_FLAG_KEY): - is_sampled_flag_valid = True - - if len(value) != self.SAMPLED_FLAG_LENGTH: - is_sampled_flag_valid = False - - if is_sampled_flag_valid: - sampled_flag = value[0] - if sampled_flag == self.IS_SAMPLED: - sampled = True - elif sampled_flag == self.NOT_SAMPLED: - sampled = False - else: - is_sampled_flag_valid = False - - if not is_sampled_flag_valid: - _logger.error( - ( - "Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", - self.TRACE_HEADER_KEY, - trace_header, - ) - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + if err is not None: + return trace.set_span_in_context( + trace.INVALID_SPAN, context=context + ) options = 0 if sampled: @@ -235,6 +109,131 @@ class AwsXRayFormat(TextMapPropagator): trace.DefaultSpan(span_context), context=context ) + def extract_span_properties(self, trace_header): + trace_id = trace.INVALID_TRACE_ID + span_id = trace.INVALID_SPAN_ID + sampled = False + + extract_err = None + + for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER): + if extract_err: + break + + try: + key_str, value_str = kv_pair_str.split( + self.KEY_AND_VALUE_DELIMITER + ) + key, value = key_str.strip(), value_str.strip() + except ValueError: + _logger.error( + ( + "Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.", + kv_pair_str, + ) + ) + return trace_id, span_id, sampled, extract_err + + if key == self.TRACE_ID_KEY: + if not self.validate_trace_id(value): + _logger.error( + ( + "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", + self.TRACE_HEADER_KEY, + trace_header, + ) + ) + extract_err = True + break + + try: + trace_id = self.parse_trace_id(value) + except ValueError: + _logger.error( + ( + "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", + self.TRACE_HEADER_KEY, + trace_header, + ) + ) + extract_err = True + elif key == self.PARENT_ID_KEY: + if not self.validate_span_id(value): + _logger.error( + ( + "Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", + self.TRACE_HEADER_KEY, + trace_header, + ) + ) + extract_err = True + break + + try: + span_id = AwsXRayFormat.parse_span_id(value) + except ValueError: + _logger.error( + ( + "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", + self.TRACE_HEADER_KEY, + trace_header, + ) + ) + extract_err = True + elif key == self.SAMPLED_FLAG_KEY: + if not self.validate_sampled_flag(value): + _logger.error( + ( + "Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", + self.TRACE_HEADER_KEY, + trace_header, + ) + ) + extract_err = True + break + + sampled = self.parse_sampled_flag(value) + + return trace_id, span_id, sampled, extract_err + + def validate_trace_id(self, trace_id_str): + return ( + len(trace_id_str) == self.TRACE_ID_LENGTH + and trace_id_str.startswith(self.TRACE_ID_VERSION) + and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_1] + == self.TRACE_ID_DELIMITER + and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_2] + == self.TRACE_ID_DELIMITER + ) + + def parse_trace_id(self, trace_id_str): + timestamp_subset = trace_id_str[ + self.TRACE_ID_DELIMITER_INDEX_1 + + 1 : self.TRACE_ID_DELIMITER_INDEX_2 + ] + unique_id_subset = trace_id_str[ + self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH + ] + return int(timestamp_subset + unique_id_subset, 16) + + def validate_span_id(self, span_id_str): + return len(span_id_str) == self.PARENT_ID_LENGTH + + @staticmethod + def parse_span_id(span_id_str): + return int(span_id_str, 16) + + def validate_sampled_flag(self, sampled_flag_str): + return len( + sampled_flag_str + ) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in ( + self.IS_SAMPLED, + self.NOT_SAMPLED, + ) + + def parse_sampled_flag(self, sampled_flag_str): + return sampled_flag_str[0] == self.IS_SAMPLED + def inject( self, set_in_carrier: Setter[TextMapPropagatorT], diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py index fc0946bab..0ea84ef21 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py @@ -247,6 +247,41 @@ class AwsXRayPropagatorTest(unittest.TestCase): get_extracted_span_context(build_test_context()), ) + def test_extract_with_extra_whitespace(self): + default_xray_trace_header_dict = build_dict_with_xray_trace_header() + trace_header_components = default_xray_trace_header_dict[ + AwsXRayFormat.TRACE_HEADER_KEY + ].split(AwsXRayFormat.KV_PAIR_DELIMITER) + xray_trace_header_dict_with_extra_whitespace = CaseInsensitiveDict( + { + AwsXRayFormat.TRACE_HEADER_KEY: AwsXRayFormat.KV_PAIR_DELIMITER.join( + [ + AwsXRayFormat.KEY_AND_VALUE_DELIMITER.join( + [ + " " + key + " ", + " " + value + " ", + ] + ) + for kv_pair_str in trace_header_components + for key, value in [ + kv_pair_str.split( + AwsXRayFormat.KEY_AND_VALUE_DELIMITER + ) + ] + ] + ) + } + ) + actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( + AwsXRayPropagatorTest.carrier_getter, + xray_trace_header_dict_with_extra_whitespace, + ) + + self.assertEqual( + get_extracted_span_context(actual_context_encompassing_extracted), + get_extracted_span_context(build_test_context()), + ) + def test_extract_invalid_xray_trace_header(self): actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( AwsXRayPropagatorTest.carrier_getter,