Return none for Getter if key does not exist (#233)

This commit is contained in:
Leighton Chen
2020-12-08 11:22:38 -05:00
committed by GitHub
parent 3b48a38948
commit 3eb27ca466
9 changed files with 114 additions and 14 deletions

View File

@ -6,8 +6,7 @@ on:
- 'release/*'
pull_request:
env:
CORE_REPO_SHA: f69e12fba8d0afd587dd21adbedfe751153aa73c
CORE_REPO_SHA: master
jobs:
build:

View File

@ -2,6 +2,9 @@
## Unreleased
- Return `None` for `CarrierGetter` if key not found
([#1374](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/233))
## Version 0.12b0
Released 2020-08-14

View File

@ -34,7 +34,9 @@ from opentelemetry.trace.status import Status, StatusCode
class CarrierGetter(DictGetter):
def get(self, carrier: dict, key: str) -> typing.List[str]:
def get(
self, carrier: dict, key: str
) -> typing.Optional[typing.List[str]]:
"""Getter implementation to retrieve a HTTP header value from the ASGI
scope.
@ -43,14 +45,17 @@ class CarrierGetter(DictGetter):
key: header name in scope
Returns:
A list with a single string with the header value if it exists,
else an empty list.
else None.
"""
headers = carrier.get("headers")
return [
decoded = [
_value.decode("utf8")
for (_key, _value) in headers
if _key.decode("utf8") == key
]
if not decoded:
return None
return decoded
carrier_getter = CarrierGetter()
@ -82,11 +87,12 @@ def collect_request_attributes(scope):
http_method = scope.get("method")
if http_method:
result["http.method"] = http_method
http_host_value = ",".join(carrier_getter.get(scope, "host"))
if http_host_value:
result["http.server_name"] = http_host_value
http_host_value_list = carrier_getter.get(scope, "host")
if http_host_value_list:
result["http.server_name"] = ",".join(http_host_value_list)
http_user_agent = carrier_getter.get(scope, "user-agent")
if len(http_user_agent) > 0:
if http_user_agent:
result["http.user_agent"] = http_user_agent[0]
if "client" in scope and scope["client"] is not None:

View File

@ -164,7 +164,7 @@ class TestAsgiApplication(AsgiTestBase):
outputs = self.get_all_output()
self.validate_outputs(outputs)
def test_wsgi_not_recording(self):
def test_asgi_not_recording(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
@ -312,8 +312,12 @@ class TestAsgiAttributes(unittest.TestCase):
def test_request_attributes(self):
self.scope["query_string"] = b"foo=bar"
headers = []
headers.append(("host".encode("utf8"), "test".encode("utf8")))
self.scope["headers"] = headers
attrs = otel_asgi.collect_request_attributes(self.scope)
self.assertDictEqual(
attrs,
{
@ -324,6 +328,7 @@ class TestAsgiAttributes(unittest.TestCase):
"http.url": "http://127.0.0.1/?foo=bar",
"host.port": 80,
"http.scheme": "http",
"http.server_name": "test",
"http.flavor": "1.0",
"net.peer.ip": "127.0.0.1",
"net.peer.port": 32767,

View File

@ -80,7 +80,9 @@ _MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id"
class CarrierGetter(DictGetter):
def get(self, carrier, key):
value = getattr(carrier, key, [])
value = getattr(carrier, key, None)
if value is None:
return None
if isinstance(value, str) or not isinstance(value, Iterable):
value = (value,)
return value

View File

@ -0,0 +1,44 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, mock
from opentelemetry.instrumentation.celery import CarrierGetter
class TestCarrierGetter(TestCase):
def test_get_none(self):
getter = CarrierGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)
def test_get_str(self):
mock_obj = mock.Mock()
getter = CarrierGetter()
mock_obj.test = "val"
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("val",))
def test_get_iter(self):
mock_obj = mock.Mock()
getter = CarrierGetter()
mock_obj.test = ["val"]
val = getter.get(mock_obj, "test")
self.assertEqual(val, ["val"])
def test_keys(self):
getter = CarrierGetter()
keys = getter.keys({})
self.assertEqual(keys, [])

View File

@ -2,6 +2,9 @@
## Unreleased
- Return `None` for `CarrierGetter` if key not found
([#1374](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/233))
## Version 0.13b0
Released 2020-09-17

View File

@ -68,7 +68,9 @@ _HTTP_VERSION_PREFIX = "HTTP/"
class CarrierGetter(DictGetter):
def get(self, carrier: dict, key: str) -> typing.List[str]:
def get(
self, carrier: dict, key: str
) -> typing.Optional[typing.List[str]]:
"""Getter implementation to retrieve a HTTP header value from the
PEP3333-conforming WSGI environ
@ -77,13 +79,13 @@ class CarrierGetter(DictGetter):
key: header name in environ object
Returns:
A list with a single string with the header value if it exists,
else an empty list.
else None.
"""
environ_key = "HTTP_" + key.upper().replace("-", "_")
value = carrier.get(environ_key)
if value is not None:
return [value]
return []
return None
def keys(self, carrier):
return []

View File

@ -0,0 +1,36 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, mock
from opentelemetry.instrumentation.wsgi import CarrierGetter
class TestCarrierGetter(TestCase):
def test_get_none(self):
getter = CarrierGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)
def test_get_(self):
getter = CarrierGetter()
carrier = {"HTTP_TEST_KEY": "val"}
val = getter.get(carrier, "test-key")
self.assertEqual(val, ["val"])
def test_keys(self):
getter = CarrierGetter()
keys = getter.keys({})
self.assertEqual(keys, [])