Use an exception to catch inability to parse

This commit is contained in:
Nathaniel Ruiz Nowell
2020-11-08 12:46:21 -08:00
parent ccb7a83922
commit 26db369ec9
5 changed files with 50 additions and 82 deletions

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
import random import random
import time
from opentelemetry import trace from opentelemetry import trace
@ -28,10 +28,13 @@ class AwsXRayIdsGenerator(trace.IdsGenerator):
See: https://docs.aws.amazon.com/xray/latest/devguide/xray-api-sendingdata.html#xray-api-traceids See: https://docs.aws.amazon.com/xray/latest/devguide/xray-api-sendingdata.html#xray-api-traceids
""" """
def generate_span_id(self) -> int: random_ids_generator = trace.RandomIdsGenerator()
return trace.RandomIdsGenerator().generate_span_id()
def generate_trace_id(self) -> int: def generate_span_id(self) -> int:
trace_time = int(datetime.datetime.utcnow().timestamp()) return self.random_ids_generator.generate_span_id()
@staticmethod
def generate_trace_id() -> int:
trace_time = int(time.time())
trace_identifier = random.getrandbits(96) trace_identifier = random.getrandbits(96)
return (trace_time << 96) + trace_identifier return (trace_time << 96) + trace_identifier

View File

@ -27,6 +27,12 @@ from opentelemetry.trace.propagation.textmap import (
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
class AwsParseTraceHeaderError(Exception):
def __init__(self, message):
super().__init__()
self.message = message
class AwsXRayFormat(TextMapPropagator): class AwsXRayFormat(TextMapPropagator):
"""Propagator for the AWS X-Ray Trace Header propagation protocol. """Propagator for the AWS X-Ray Trace Header propagation protocol.
@ -76,11 +82,12 @@ class AwsXRayFormat(TextMapPropagator):
trace.INVALID_SPAN, context=context trace.INVALID_SPAN, context=context
) )
trace_id, span_id, sampled, err = self.extract_span_properties( try:
trace_header trace_id, span_id, sampled = self._extract_span_properties(
) trace_header
)
if err is not None: except AwsParseTraceHeaderError as err:
_logger.debug(err.message)
return trace.set_span_in_context( return trace.set_span_in_context(
trace.INVALID_SPAN, context=context trace.INVALID_SPAN, context=context
) )
@ -98,7 +105,7 @@ class AwsXRayFormat(TextMapPropagator):
) )
if not span_context.is_valid: if not span_context.is_valid:
_logger.error( _logger.debug(
"Invalid Span Extracted. Insertting INVALID span into provided context." "Invalid Span Extracted. Insertting INVALID span into provided context."
) )
return trace.set_span_in_context( return trace.set_span_in_context(
@ -109,94 +116,79 @@ class AwsXRayFormat(TextMapPropagator):
trace.DefaultSpan(span_context), context=context trace.DefaultSpan(span_context), context=context
) )
def extract_span_properties(self, trace_header): def _extract_span_properties(self, trace_header):
trace_id = trace.INVALID_TRACE_ID trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID span_id = trace.INVALID_SPAN_ID
sampled = False sampled = False
extract_err = None
for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER): for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER):
if extract_err:
break
try: try:
key_str, value_str = kv_pair_str.split( key_str, value_str = kv_pair_str.split(
self.KEY_AND_VALUE_DELIMITER self.KEY_AND_VALUE_DELIMITER
) )
key, value = key_str.strip(), value_str.strip() key, value = key_str.strip(), value_str.strip()
except ValueError: except ValueError as ex:
_logger.error( raise AwsParseTraceHeaderError(
( (
"Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.", "Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.",
kv_pair_str, kv_pair_str,
) )
) ) from ex
return trace_id, span_id, sampled, extract_err
if key == self.TRACE_ID_KEY: if key == self.TRACE_ID_KEY:
if not self.validate_trace_id(value): if not self._validate_trace_id(value):
_logger.error( raise AwsParseTraceHeaderError(
( (
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY, self.TRACE_HEADER_KEY,
trace_header, trace_header,
) )
) )
extract_err = True
break
try: try:
trace_id = self.parse_trace_id(value) trace_id = self._parse_trace_id(value)
except ValueError: except ValueError as ex:
_logger.error( raise AwsParseTraceHeaderError(
( (
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY, self.TRACE_HEADER_KEY,
trace_header, trace_header,
) )
) ) from ex
extract_err = True
elif key == self.PARENT_ID_KEY: elif key == self.PARENT_ID_KEY:
if not self.validate_span_id(value): if not self._validate_span_id(value):
_logger.error( raise AwsParseTraceHeaderError(
( (
"Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", "Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY, self.TRACE_HEADER_KEY,
trace_header, trace_header,
) )
) )
extract_err = True
break
try: try:
span_id = AwsXRayFormat.parse_span_id(value) span_id = AwsXRayFormat._parse_span_id(value)
except ValueError: except ValueError as ex:
_logger.error( raise AwsParseTraceHeaderError(
( (
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY, self.TRACE_HEADER_KEY,
trace_header, trace_header,
) )
) ) from ex
extract_err = True
elif key == self.SAMPLED_FLAG_KEY: elif key == self.SAMPLED_FLAG_KEY:
if not self.validate_sampled_flag(value): if not self._validate_sampled_flag(value):
_logger.error( raise AwsParseTraceHeaderError(
( (
"Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", "Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.",
self.TRACE_HEADER_KEY, self.TRACE_HEADER_KEY,
trace_header, trace_header,
) )
) )
extract_err = True
break
sampled = self.parse_sampled_flag(value) sampled = self._parse_sampled_flag(value)
return trace_id, span_id, sampled, extract_err return trace_id, span_id, sampled
def validate_trace_id(self, trace_id_str): def _validate_trace_id(self, trace_id_str):
return ( return (
len(trace_id_str) == self.TRACE_ID_LENGTH len(trace_id_str) == self.TRACE_ID_LENGTH
and trace_id_str.startswith(self.TRACE_ID_VERSION) and trace_id_str.startswith(self.TRACE_ID_VERSION)
@ -206,7 +198,7 @@ class AwsXRayFormat(TextMapPropagator):
== self.TRACE_ID_DELIMITER == self.TRACE_ID_DELIMITER
) )
def parse_trace_id(self, trace_id_str): def _parse_trace_id(self, trace_id_str):
timestamp_subset = trace_id_str[ timestamp_subset = trace_id_str[
self.TRACE_ID_DELIMITER_INDEX_1 self.TRACE_ID_DELIMITER_INDEX_1
+ 1 : self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_DELIMITER_INDEX_2
@ -216,14 +208,14 @@ class AwsXRayFormat(TextMapPropagator):
] ]
return int(timestamp_subset + unique_id_subset, 16) return int(timestamp_subset + unique_id_subset, 16)
def validate_span_id(self, span_id_str): def _validate_span_id(self, span_id_str):
return len(span_id_str) == self.PARENT_ID_LENGTH return len(span_id_str) == self.PARENT_ID_LENGTH
@staticmethod @staticmethod
def parse_span_id(span_id_str): def _parse_span_id(span_id_str):
return int(span_id_str, 16) return int(span_id_str, 16)
def validate_sampled_flag(self, sampled_flag_str): def _validate_sampled_flag(self, sampled_flag_str):
return len( return len(
sampled_flag_str sampled_flag_str
) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in ( ) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in (
@ -231,7 +223,7 @@ class AwsXRayFormat(TextMapPropagator):
self.NOT_SAMPLED, self.NOT_SAMPLED,
) )
def parse_sampled_flag(self, sampled_flag_str): def _parse_sampled_flag(self, sampled_flag_str):
return sampled_flag_str[0] == self.IS_SAMPLED return sampled_flag_str[0] == self.IS_SAMPLED
def inject( def inject(

View File

@ -1,13 +0,0 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,13 +0,0 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import time
import unittest import unittest
from opentelemetry.sdk.extension.aws.trace import AwsXRayIdsGenerator from opentelemetry.sdk.extension.aws.trace import AwsXRayIdsGenerator
@ -33,11 +34,9 @@ class AwsXRayIdsGeneratorTest(unittest.TestCase):
for _ in range(1000): for _ in range(1000):
trace_id = ids_generator.generate_trace_id() trace_id = ids_generator.generate_trace_id()
trace_id_time = trace_id >> 96 trace_id_time = trace_id >> 96
current_time = int(datetime.datetime.utcnow().timestamp()) current_time = int(time.time())
self.assertLessEqual(trace_id_time, current_time) self.assertLessEqual(trace_id_time, current_time)
one_month_ago_time = int( one_month_ago_time = int(
( (datetime.datetime.now() - datetime.timedelta(30)).timestamp()
datetime.datetime.utcnow() - datetime.timedelta(30)
).timestamp()
) )
self.assertGreater(trace_id_time, one_month_ago_time) self.assertGreater(trace_id_time, one_month_ago_time)