Add instrumentor and auto instrumentation support for aiohttp (#1075)

This commit is contained in:
Mario Jonke
2020-10-09 16:18:22 +02:00
committed by alrex
parent 5f100f00a3
commit 8bc6abdbf2
4 changed files with 377 additions and 69 deletions

View File

@ -8,6 +8,8 @@ Released 2020-09-17
- Updating span name to match semantic conventions - Updating span name to match semantic conventions
([#972](https://github.com/open-telemetry/opentelemetry-python/pull/972)) ([#972](https://github.com/open-telemetry/opentelemetry-python/pull/972))
- Add instrumentor and auto instrumentation support for aiohttp
([#1075](https://github.com/open-telemetry/opentelemetry-python/pull/1075))
## Version 0.12b0 ## Version 0.12b0

View File

@ -39,12 +39,17 @@ package_dir=
=src =src
packages=find_namespace: packages=find_namespace:
install_requires = install_requires =
opentelemetry-api >= 0.12.dev0 opentelemetry-api == 0.14.dev0
opentelemetry-instrumentation == 0.14.dev0 opentelemetry-instrumentation == 0.14.dev0
aiohttp ~= 3.0 aiohttp ~= 3.0
wrapt >= 1.0.0, < 2.0.0
[options.packages.find] [options.packages.find]
where = src where = src
[options.extras_require] [options.extras_require]
test = test =
[options.entry_points]
opentelemetry_instrumentor =
aiohttp-client = opentelemetry.instrumentation.aiohttp_client:AioHttpClientInstrumentor

View File

@ -18,44 +18,73 @@ requests made by the aiohttp client library.
Usage Usage
----- -----
Explicitly instrumenting a single client session:
.. code:: python .. code:: python
import aiohttp import aiohttp
from opentelemetry.instrumentation.aiohttp_client import ( from opentelemetry.instrumentation.aiohttp_client import (
create_trace_config, create_trace_config,
url_path_span_name url_path_span_name
) )
import yarl import yarl
def strip_query_params(url: yarl.URL) -> str: def strip_query_params(url: yarl.URL) -> str:
return str(url.with_query(None)) return str(url.with_query(None))
async with aiohttp.ClientSession(trace_configs=[create_trace_config( async with aiohttp.ClientSession(trace_configs=[create_trace_config(
# Remove all query params from the URL attribute on the span. # Remove all query params from the URL attribute on the span.
url_filter=strip_query_params, url_filter=strip_query_params,
# Use the URL's path as the span name. # Use the URL's path as the span name.
span_name=url_path_span_name span_name=url_path_span_name
)]) as session: )]) as session:
async with session.get(url) as response: async with session.get(url) as response:
await response.text() await response.text()
Instrumenting all client sessions:
.. code:: python
import aiohttp
from opentelemetry.instrumentation.aiohttp_client import (
AioHttpClientInstrumentor
)
# Enable instrumentation
AioHttpClientInstrumentor().instrument()
# Create a session and make an HTTP get request
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
await response.text()
API
---
""" """
import contextlib
import socket import socket
import types import types
import typing import typing
import aiohttp import aiohttp
import wrapt
from opentelemetry import context as context_api from opentelemetry import context as context_api
from opentelemetry import propagators, trace from opentelemetry import propagators, trace
from opentelemetry.instrumentation.aiohttp_client.version import __version__ from opentelemetry.instrumentation.aiohttp_client.version import __version__
from opentelemetry.instrumentation.utils import http_status_to_canonical_code from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import SpanKind from opentelemetry.instrumentation.utils import (
http_status_to_canonical_code,
unwrap,
)
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer
from opentelemetry.trace.status import Status, StatusCanonicalCode from opentelemetry.trace.status import Status, StatusCanonicalCode
_UrlFilterT = typing.Optional[typing.Callable[[str], str]]
_SpanNameT = typing.Optional[
typing.Union[typing.Callable[[aiohttp.TraceRequestStartParams], str], str]
]
def url_path_span_name(params: aiohttp.TraceRequestStartParams) -> str: def url_path_span_name(params: aiohttp.TraceRequestStartParams) -> str:
"""Extract a span name from the request URL path. """Extract a span name from the request URL path.
@ -73,12 +102,9 @@ def url_path_span_name(params: aiohttp.TraceRequestStartParams) -> str:
def create_trace_config( def create_trace_config(
url_filter: typing.Optional[typing.Callable[[str], str]] = None, url_filter: _UrlFilterT = None,
span_name: typing.Optional[ span_name: _SpanNameT = None,
typing.Union[ tracer_provider: TracerProvider = None,
typing.Callable[[aiohttp.TraceRequestStartParams], str], str
]
] = None,
) -> aiohttp.TraceConfig: ) -> aiohttp.TraceConfig:
"""Create an aiohttp-compatible trace configuration. """Create an aiohttp-compatible trace configuration.
@ -104,6 +130,7 @@ def create_trace_config(
such as API keys or user personal information. such as API keys or user personal information.
:param str span_name: Override the default span name. :param str span_name: Override the default span name.
:param tracer_provider: optional TracerProvider from which to get a Tracer
:return: An object suitable for use with :py:class:`aiohttp.ClientSession`. :return: An object suitable for use with :py:class:`aiohttp.ClientSession`.
:rtype: :py:class:`aiohttp.TraceConfig` :rtype: :py:class:`aiohttp.TraceConfig`
@ -113,7 +140,7 @@ def create_trace_config(
# Explicitly specify the type for the `span_name` param and rtype to work # Explicitly specify the type for the `span_name` param and rtype to work
# around this issue. # around this issue.
tracer = trace.get_tracer_provider().get_tracer(__name__, __version__) tracer = get_tracer(__name__, __version__, tracer_provider)
def _end_trace(trace_config_ctx: types.SimpleNamespace): def _end_trace(trace_config_ctx: types.SimpleNamespace):
context_api.detach(trace_config_ctx.token) context_api.detach(trace_config_ctx.token)
@ -124,6 +151,10 @@ def create_trace_config(
trace_config_ctx: types.SimpleNamespace, trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestStartParams, params: aiohttp.TraceRequestStartParams,
): ):
if context_api.get_value("suppress_instrumentation"):
trace_config_ctx.span = None
return
http_method = params.method.upper() http_method = params.method.upper()
if trace_config_ctx.span_name is None: if trace_config_ctx.span_name is None:
request_span_name = "HTTP {}".format(http_method) request_span_name = "HTTP {}".format(http_method)
@ -158,6 +189,9 @@ def create_trace_config(
trace_config_ctx: types.SimpleNamespace, trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestEndParams, params: aiohttp.TraceRequestEndParams,
): ):
if trace_config_ctx.span is None:
return
if trace_config_ctx.span.is_recording(): if trace_config_ctx.span.is_recording():
trace_config_ctx.span.set_status( trace_config_ctx.span.set_status(
Status( Status(
@ -177,6 +211,9 @@ def create_trace_config(
trace_config_ctx: types.SimpleNamespace, trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestExceptionParams, params: aiohttp.TraceRequestExceptionParams,
): ):
if trace_config_ctx.span is None:
return
if trace_config_ctx.span.is_recording(): if trace_config_ctx.span.is_recording():
if isinstance( if isinstance(
params.exception, params.exception,
@ -193,6 +230,7 @@ def create_trace_config(
status = StatusCanonicalCode.UNAVAILABLE status = StatusCanonicalCode.UNAVAILABLE
trace_config_ctx.span.set_status(Status(status)) trace_config_ctx.span.set_status(Status(status))
trace_config_ctx.span.record_exception(params.exception)
_end_trace(trace_config_ctx) _end_trace(trace_config_ctx)
def _trace_config_ctx_factory(**kwargs): def _trace_config_ctx_factory(**kwargs):
@ -210,3 +248,84 @@ def create_trace_config(
trace_config.on_request_exception.append(on_request_exception) trace_config.on_request_exception.append(on_request_exception)
return trace_config return trace_config
def _instrument(
    tracer_provider: TracerProvider = None,
    url_filter: _UrlFilterT = None,
    span_name: _SpanNameT = None,
):
    """Patch ``aiohttp.ClientSession.__init__`` so new sessions are traced.

    Every ``ClientSession`` constructed while the patch is active gets one
    extra ``TraceConfig`` (built by ``create_trace_config``) appended to
    whatever ``trace_configs`` the caller supplied.

    :param tracer_provider: forwarded to ``create_trace_config``.
    :param url_filter: forwarded to ``create_trace_config``.
    :param span_name: forwarded to ``create_trace_config``.
    """

    # pylint:disable=unused-argument
    def instrumented_init(wrapped, instance, args, kwargs):
        if not context_api.get_value("suppress_instrumentation"):
            # Copy the caller's configs so their sequence is never mutated;
            # a missing or ``None`` value is treated as "no configs".
            session_trace_configs = [*(kwargs.get("trace_configs") or ())]

            otel_trace_config = create_trace_config(
                url_filter=url_filter,
                span_name=span_name,
                tracer_provider=tracer_provider,
            )
            # Marker attribute that ``_uninstrument_session`` looks for when
            # stripping the instrumentation from an existing session.
            otel_trace_config.opentelemetry_aiohttp_instrumented = True

            kwargs["trace_configs"] = session_trace_configs + [
                otel_trace_config
            ]

        return wrapped(*args, **kwargs)

    wrapt.wrap_function_wrapper(
        aiohttp.ClientSession, "__init__", instrumented_init
    )
def _uninstrument():
    """Undo the ``aiohttp.ClientSession.__init__`` patch.

    Only the constructor is unwrapped here, so this affects sessions created
    afterwards; use ``_uninstrument_session`` for a session that already
    exists.
    """
    unwrap(aiohttp.ClientSession, "__init__")
def _uninstrument_session(client_session: aiohttp.ClientSession):
    """Strip the automatically added trace config from one session.

    Trace configs carrying the ``opentelemetry_aiohttp_instrumented`` marker
    attribute are dropped; every other config is kept in its original order.

    :param client_session: the session whose trace configs are filtered.
    """
    marker = "opentelemetry_aiohttp_instrumented"
    kept_configs = []
    # pylint: disable=protected-access
    for config in client_session._trace_configs:
        if hasattr(config, marker):
            continue
        kept_configs.append(config)
    client_session._trace_configs = kept_configs
class AioHttpClientInstrumentor(BaseInstrumentor):
    """Instrumentor that traces aiohttp client sessions.

    See `BaseInstrumentor`
    """

    def _instrument(self, **kwargs):
        """Enable tracing for aiohttp ``ClientSession`` objects.

        Args:
            **kwargs: Optional arguments
                ``tracer_provider``: a TracerProvider, defaults to global
                ``url_filter``: A callback to process the requested URL prior
                to adding it as a span attribute. This can be useful to
                remove sensitive data such as API keys or user personal
                information.
                ``span_name``: Override the default span name.
        """
        # Only the recognised options are forwarded; an absent option is
        # passed on as ``None``.
        option_names = ("tracer_provider", "url_filter", "span_name")
        options = {name: kwargs.get(name) for name in option_names}
        # Inside a method body this name resolves to the module-level
        # ``_instrument`` function, not to this method.
        _instrument(**options)

    def _uninstrument(self, **kwargs):
        """Disable tracing for sessions created from now on."""
        _uninstrument()

    @staticmethod
    def uninstrument_session(client_session: aiohttp.ClientSession):
        """Remove the instrumentation from a single, existing session."""
        _uninstrument_session(client_session)

View File

@ -15,6 +15,7 @@
import asyncio import asyncio
import contextlib import contextlib
import typing import typing
import unittest
import urllib.parse import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from unittest import mock from unittest import mock
@ -22,15 +23,39 @@ from unittest import mock
import aiohttp import aiohttp
import aiohttp.test_utils import aiohttp.test_utils
import yarl import yarl
from pkg_resources import iter_entry_points
import opentelemetry.instrumentation.aiohttp_client from opentelemetry import context
from opentelemetry.instrumentation import aiohttp_client
from opentelemetry.instrumentation.aiohttp_client import (
AioHttpClientInstrumentor,
)
from opentelemetry.test.test_base import TestBase from opentelemetry.test.test_base import TestBase
from opentelemetry.trace.status import StatusCanonicalCode from opentelemetry.trace.status import StatusCanonicalCode
class TestAioHttpIntegration(TestBase): def run_with_test_server(
maxDiff = None runnable: typing.Callable, url: str, handler: typing.Callable
) -> typing.Tuple[str, int]:
async def do_request():
app = aiohttp.web.Application()
parsed_url = urllib.parse.urlparse(url)
app.add_routes([aiohttp.web.get(parsed_url.path, handler)])
app.add_routes([aiohttp.web.post(parsed_url.path, handler)])
app.add_routes([aiohttp.web.patch(parsed_url.path, handler)])
with contextlib.suppress(aiohttp.ClientError):
async with aiohttp.test_utils.TestServer(app) as server:
netloc = (server.host, server.port)
await server.start_server()
await runnable(server)
return netloc
loop = asyncio.get_event_loop()
return loop.run_until_complete(do_request())
class TestAioHttpIntegration(TestBase):
def assert_spans(self, spans): def assert_spans(self, spans):
self.assertEqual( self.assertEqual(
[ [
@ -54,9 +79,7 @@ class TestAioHttpIntegration(TestBase):
): ):
with self.subTest(url=url): with self.subTest(url=url):
params = aiohttp.TraceRequestStartParams("METHOD", url, {}) params = aiohttp.TraceRequestStartParams("METHOD", url, {})
actual = opentelemetry.instrumentation.aiohttp_client.url_path_span_name( actual = aiohttp_client.url_path_span_name(params)
params
)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
self.assertIsInstance(actual, str) self.assertIsInstance(actual, str)
@ -71,33 +94,20 @@ class TestAioHttpIntegration(TestBase):
) -> typing.Tuple[str, int]: ) -> typing.Tuple[str, int]:
"""Helper to start an aiohttp test server and send an actual HTTP request to it.""" """Helper to start an aiohttp test server and send an actual HTTP request to it."""
async def do_request(): async def default_handler(request):
async def default_handler(request): assert "traceparent" in request.headers
assert "traceparent" in request.headers return aiohttp.web.Response(status=int(status_code))
return aiohttp.web.Response(status=int(status_code))
handler = request_handler or default_handler async def client_request(server: aiohttp.test_utils.TestServer):
async with aiohttp.test_utils.TestClient(
server, trace_configs=[trace_config]
) as client:
await client.request(
method, url, trace_request_ctx={}, **kwargs
)
app = aiohttp.web.Application() handler = request_handler or default_handler
parsed_url = urllib.parse.urlparse(url) return run_with_test_server(client_request, url, handler)
app.add_routes([aiohttp.web.get(parsed_url.path, handler)])
app.add_routes([aiohttp.web.post(parsed_url.path, handler)])
app.add_routes([aiohttp.web.patch(parsed_url.path, handler)])
with contextlib.suppress(aiohttp.ClientError):
async with aiohttp.test_utils.TestServer(app) as server:
netloc = (server.host, server.port)
async with aiohttp.test_utils.TestClient(
server, trace_configs=[trace_config]
) as client:
await client.start_server()
await client.request(
method, url, trace_request_ctx={}, **kwargs
)
return netloc
loop = asyncio.get_event_loop()
return loop.run_until_complete(do_request())
def test_status_codes(self): def test_status_codes(self):
for status_code, span_status in ( for status_code, span_status in (
@ -111,7 +121,7 @@ class TestAioHttpIntegration(TestBase):
): ):
with self.subTest(status_code=status_code): with self.subTest(status_code=status_code):
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config(), trace_config=aiohttp_client.create_trace_config(),
url="/test-path?query=param#foobar", url="/test-path?query=param#foobar",
status_code=status_code, status_code=status_code,
) )
@ -144,7 +154,7 @@ class TestAioHttpIntegration(TestBase):
with mock.patch("opentelemetry.trace.get_tracer"): with mock.patch("opentelemetry.trace.get_tracer"):
# pylint: disable=W0612 # pylint: disable=W0612
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config(), trace_config=aiohttp_client.create_trace_config(),
url="/test-path?query=param#foobar", url="/test-path?query=param#foobar",
) )
self.assertFalse(mock_span.is_recording()) self.assertFalse(mock_span.is_recording())
@ -166,7 +176,7 @@ class TestAioHttpIntegration(TestBase):
): ):
with self.subTest(span_name=span_name, method=method, path=path): with self.subTest(span_name=span_name, method=method, path=path):
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config( trace_config=aiohttp_client.create_trace_config(
span_name=span_name span_name=span_name
), ),
method=method, method=method,
@ -199,7 +209,7 @@ class TestAioHttpIntegration(TestBase):
return str(url.with_query(None)) return str(url.with_query(None))
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config( trace_config=aiohttp_client.create_trace_config(
url_filter=strip_query_params url_filter=strip_query_params
), ),
url="/some/path?query=param&other=param2", url="/some/path?query=param&other=param2",
@ -225,9 +235,7 @@ class TestAioHttpIntegration(TestBase):
) )
def test_connection_errors(self): def test_connection_errors(self):
trace_configs = [ trace_configs = [aiohttp_client.create_trace_config()]
opentelemetry.instrumentation.aiohttp_client.create_trace_config()
]
for url, expected_status in ( for url, expected_status in (
("http://this-is-unknown.local/", StatusCanonicalCode.UNKNOWN), ("http://this-is-unknown.local/", StatusCanonicalCode.UNKNOWN),
@ -237,7 +245,7 @@ class TestAioHttpIntegration(TestBase):
async def do_request(url): async def do_request(url):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
trace_configs=trace_configs trace_configs=trace_configs,
) as session: ) as session:
async with session.get(url): async with session.get(url):
pass pass
@ -268,7 +276,7 @@ class TestAioHttpIntegration(TestBase):
return aiohttp.web.Response() return aiohttp.web.Response()
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config(), trace_config=aiohttp_client.create_trace_config(),
url="/test_timeout", url="/test_timeout",
request_handler=request_handler, request_handler=request_handler,
timeout=aiohttp.ClientTimeout(sock_read=0.01), timeout=aiohttp.ClientTimeout(sock_read=0.01),
@ -298,7 +306,7 @@ class TestAioHttpIntegration(TestBase):
raise aiohttp.web.HTTPFound(location=location) raise aiohttp.web.HTTPFound(location=location)
host, port = self._http_request( host, port = self._http_request(
trace_config=opentelemetry.instrumentation.aiohttp_client.create_trace_config(), trace_config=aiohttp_client.create_trace_config(),
url="/test_too_many_redirects", url="/test_too_many_redirects",
request_handler=request_handler, request_handler=request_handler,
max_redirects=2, max_redirects=2,
@ -319,3 +327,177 @@ class TestAioHttpIntegration(TestBase):
) )
] ]
) )
class TestAioHttpClientInstrumentor(TestBase):
    """Tests for ``AioHttpClientInstrumentor`` against a local test server.

    Instrumentation is enabled in ``setUp`` and disabled in ``tearDown`` so
    each test starts from the same state.
    """

    URL = "/test-path"

    def setUp(self):
        super().setUp()
        AioHttpClientInstrumentor().instrument()

    def tearDown(self):
        super().tearDown()
        AioHttpClientInstrumentor().uninstrument()

    @staticmethod
    # pylint:disable=unused-argument
    async def default_handler(request):
        """Server handler that always answers with an empty 200 response."""
        return aiohttp.web.Response(status=200)

    @staticmethod
    def get_default_request(url: str = URL):
        """Return a coroutine function that sends one GET to ``url``."""

        async def default_request(server: aiohttp.test_utils.TestServer):
            async with aiohttp.test_utils.TestClient(server) as session:
                await session.get(url)

        return default_request

    def assert_spans(self, num_spans: int):
        """Assert the number of finished spans and return them.

        Returns ``None`` for zero spans, the single span for one span, and
        the whole sequence otherwise.
        """
        finished_spans = self.memory_exporter.get_finished_spans()
        self.assertEqual(num_spans, len(finished_spans))
        if num_spans == 0:
            return None
        if num_spans == 1:
            return finished_spans[0]
        return finished_spans

    def test_instrument(self):
        host, port = run_with_test_server(
            self.get_default_request(), self.URL, self.default_handler
        )
        span = self.assert_spans(1)
        self.assertEqual("http", span.attributes["component"])
        self.assertEqual("GET", span.attributes["http.method"])
        self.assertEqual(
            "http://{}:{}/test-path".format(host, port),
            span.attributes["http.url"],
        )
        self.assertEqual(200, span.attributes["http.status_code"])
        self.assertEqual("OK", span.attributes["http.status_text"])

    def test_instrument_with_existing_trace_config(self):
        trace_config = aiohttp.TraceConfig()

        async def create_session(server: aiohttp.test_utils.TestServer):
            async with aiohttp.test_utils.TestClient(
                server, trace_configs=[trace_config]
            ) as client:
                # pylint:disable=protected-access
                trace_configs = client.session._trace_configs
                # The user's config must survive next to the injected one.
                self.assertEqual(2, len(trace_configs))
                self.assertIn(trace_config, trace_configs)
                # The client is already entered above; issue the request
                # directly instead of entering the same client a second time.
                await client.get(self.URL)

        run_with_test_server(create_session, self.URL, self.default_handler)
        self.assert_spans(1)

    def test_uninstrument(self):
        AioHttpClientInstrumentor().uninstrument()
        run_with_test_server(
            self.get_default_request(), self.URL, self.default_handler
        )

        self.assert_spans(0)

        AioHttpClientInstrumentor().instrument()
        run_with_test_server(
            self.get_default_request(), self.URL, self.default_handler
        )
        self.assert_spans(1)

    def test_uninstrument_session(self):
        async def uninstrument_request(server: aiohttp.test_utils.TestServer):
            client = aiohttp.test_utils.TestClient(server)
            AioHttpClientInstrumentor().uninstrument_session(client.session)
            async with client as session:
                await session.get(self.URL)

        run_with_test_server(
            uninstrument_request, self.URL, self.default_handler
        )
        self.assert_spans(0)

        # Sessions created afterwards must still be instrumented.
        run_with_test_server(
            self.get_default_request(), self.URL, self.default_handler
        )
        self.assert_spans(1)

    def test_suppress_instrumentation(self):
        token = context.attach(
            context.set_value("suppress_instrumentation", True)
        )
        try:
            run_with_test_server(
                self.get_default_request(), self.URL, self.default_handler
            )
        finally:
            context.detach(token)
        self.assert_spans(0)

    @staticmethod
    async def suppressed_request(server: aiohttp.test_utils.TestServer):
        """Send one GET while ``suppress_instrumentation`` is set.

        The session is created *before* the flag is attached, so this
        exercises suppression at request time rather than at session
        creation time.
        """
        async with aiohttp.test_utils.TestClient(server) as client:
            token = context.attach(
                context.set_value("suppress_instrumentation", True)
            )
            try:
                await client.get(TestAioHttpClientInstrumentor.URL)
            finally:
                # Always restore the previous context, even when the request
                # raises, so the suppression flag cannot outlive this call.
                context.detach(token)

    def test_suppress_instrumentation_after_creation(self):
        run_with_test_server(
            self.suppressed_request, self.URL, self.default_handler
        )
        self.assert_spans(0)

    def test_suppress_instrumentation_with_server_exception(self):
        # pylint:disable=unused-argument
        async def raising_handler(request):
            raise aiohttp.web.HTTPFound(location=self.URL)

        run_with_test_server(
            self.suppressed_request, self.URL, raising_handler
        )
        self.assert_spans(0)

    def test_url_filter(self):
        def strip_query_params(url: yarl.URL) -> str:
            return str(url.with_query(None))

        # Re-instrument with custom options; setUp used the defaults.
        AioHttpClientInstrumentor().uninstrument()
        AioHttpClientInstrumentor().instrument(url_filter=strip_query_params)

        url = "/test-path?query=params"
        host, port = run_with_test_server(
            self.get_default_request(url), url, self.default_handler
        )
        span = self.assert_spans(1)
        self.assertEqual(
            "http://{}:{}/test-path".format(host, port),
            span.attributes["http.url"],
        )

    def test_span_name(self):
        def span_name_callback(params: aiohttp.TraceRequestStartParams) -> str:
            return "{} - {}".format(params.method, params.url.path)

        # Re-instrument with custom options; setUp used the defaults.
        AioHttpClientInstrumentor().uninstrument()
        AioHttpClientInstrumentor().instrument(span_name=span_name_callback)

        url = "/test-path"
        run_with_test_server(
            self.get_default_request(url), url, self.default_handler
        )
        span = self.assert_spans(1)
        self.assertEqual("GET - /test-path", span.name)
class TestLoadingAioHttpInstrumentor(unittest.TestCase):
    """Checks that the instrumentor is discoverable via its entry point."""

    def test_loading_instrumentor(self):
        # Look up the "aiohttp-client" entry in the
        # "opentelemetry_instrumentor" group and instantiate what it loads.
        matching_entry_points = iter_entry_points(
            "opentelemetry_instrumentor", "aiohttp-client"
        )
        first_entry_point = next(matching_entry_points)
        instrumentor_factory = first_entry_point.load()
        self.assertIsInstance(
            instrumentor_factory(), AioHttpClientInstrumentor
        )