Update instrumentations to use tracer_provider for creating tracer if given, otherwise use global tracer provider (#402)

This commit is contained in:
Srikanth Chekuri
2021-04-28 21:06:37 +05:30
committed by GitHub
parent bdbc249ff0
commit 3ec77360cb
33 changed files with 408 additions and 95 deletions

View File

@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#387](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/387))
- Update redis instrumentation to follow semantic conventions
([#403](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/403))
- Update instrumentations to use tracer_provider for creating tracer if given, otherwise use global tracer provider
([#402](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/402))
- `opentelemetry-instrumentation-wsgi` Replaced `name_callback` with `request_hook`
and `response_hook` callbacks.
([#424](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/424))

View File

@ -83,17 +83,20 @@ class AiopgInstrumentor(BaseInstrumentor):
tracer_provider=tracer_provider,
)
# pylint:disable=no-self-use
def _uninstrument(self, **kwargs):
""""Disable aiopg instrumentation"""
wrappers.unwrap_connect()
wrappers.unwrap_create_pool()
# pylint:disable=no-self-use
def instrument_connection(self, connection):
def instrument_connection(self, connection, tracer_provider=None):
"""Enable instrumentation in a aiopg connection.
Args:
connection: The connection to instrument.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
Returns:
An instrumented connection.
@ -103,6 +106,8 @@ class AiopgInstrumentor(BaseInstrumentor):
connection,
self._DATABASE_SYSTEM,
self._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)
def uninstrument_connection(self, connection):

View File

@ -114,7 +114,7 @@ class AsyncCursorTracer(CursorTracer):
else self._db_api_integration.name
)
with self._db_api_integration.get_tracer().start_as_current_span(
with self._db_api_integration._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
self._populate_span(span, cursor, *args)

View File

@ -204,6 +204,32 @@ class TestAiopgInstrumentor(TestBase):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
def test_custom_tracer_provider_instrument_connection(self):
resource = resources.Resource.create(
{"service.name": "db-test-service"}
)
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
cnx = async_call(aiopg.connect(database="test"))
cnx = AiopgInstrumentor().instrument_connection(
cnx, tracer_provider=tracer_provider
)
cursor = async_call(cnx.cursor())
query = "SELECT * FROM test"
async_call(cursor.execute(query))
spans_list = exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(
span.resource.attributes["service.name"], "db-test-service"
)
self.assertIs(span.resource, resource)
def test_uninstrument_connection(self):
AiopgInstrumentor().instrument()
cnx = async_call(aiopg.connect(database="test"))

View File

@ -167,11 +167,19 @@ class OpenTelemetryMiddleware:
and a tuple, representing the desired span name and a
dictionary with any additional span attributes to set.
Optional: Defaults to get_default_span_details.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
"""
def __init__(self, app, excluded_urls=None, span_details_callback=None):
def __init__(
self,
app,
excluded_urls=None,
span_details_callback=None,
tracer_provider=None,
):
self.app = guarantee_single_callable(app)
self.tracer = trace.get_tracer(__name__, __version__)
self.tracer = trace.get_tracer(__name__, __version__, tracer_provider)
self.span_details_callback = (
span_details_callback or get_default_span_details
)

View File

@ -18,11 +18,13 @@ import unittest.mock as mock
import opentelemetry.instrumentation.asgi as otel_asgi
from opentelemetry import trace as trace_api
from opentelemetry.sdk import resources
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.asgitestutil import (
AsgiTestBase,
setup_testing_defaults,
)
from opentelemetry.test.test_base import TestBase
async def http_app(scope, receive, send):
@ -211,6 +213,22 @@ class TestAsgiApplication(AsgiTestBase):
outputs = self.get_all_output()
self.validate_outputs(outputs, modifiers=[update_expected_span_name])
def test_custom_tracer_provider_otel_asgi(self):
resource = resources.Resource.create({"service-test-key": "value"})
result = TestBase.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
app = otel_asgi.OpenTelemetryMiddleware(
simple_asgi, tracer_provider=tracer_provider
)
self.seed_app(app)
self.send_default_request()
span_list = exporter.get_finished_spans()
for span in span_list:
self.assertEqual(
span.resource.attributes["service-test-key"], "value"
)
def test_behavior_with_scope_server_as_none(self):
"""Test that middleware is ok when server is none in scope."""

View File

@ -50,8 +50,6 @@ from opentelemetry.semconv.trace import (
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import Status, StatusCode
_APPLIED = "_opentelemetry_tracer"
def _hydrate_span_from_args(connection, query, parameters) -> dict:
"""Get network and database attributes from connection."""
@ -98,16 +96,11 @@ class AsyncPGInstrumentor(BaseInstrumentor):
def __init__(self, capture_parameters=False):
super().__init__()
self.capture_parameters = capture_parameters
self._tracer = None
def _instrument(self, **kwargs):
tracer_provider = kwargs.get(
"tracer_provider", trace.get_tracer_provider()
)
setattr(
asyncpg,
_APPLIED,
tracer_provider.get_tracer("asyncpg", __version__),
)
tracer_provider = kwargs.get("tracer_provider")
self._tracer = trace.get_tracer(__name__, __version__, tracer_provider)
for method in [
"Connection.execute",
@ -121,7 +114,6 @@ class AsyncPGInstrumentor(BaseInstrumentor):
)
def _uninstrument(self, **__):
delattr(asyncpg, _APPLIED)
for method in [
"execute",
"executemany",
@ -132,13 +124,14 @@ class AsyncPGInstrumentor(BaseInstrumentor):
unwrap(asyncpg.Connection, method)
async def _do_execute(self, func, instance, args, kwargs):
tracer = getattr(asyncpg, _APPLIED)
exception = None
params = getattr(instance, "_params", {})
name = args[0] if args[0] else params.get("database", "postgresql")
with tracer.start_as_current_span(name, kind=SpanKind.CLIENT) as span:
with self._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
if span.is_recording():
span_attributes = _hydrate_span_from_args(
instance,

View File

@ -1,4 +1,3 @@
import asyncpg
from asyncpg import Connection
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
@ -6,12 +5,6 @@ from opentelemetry.test.test_base import TestBase
class TestAsyncPGInstrumentation(TestBase):
def test_instrumentation_flags(self):
AsyncPGInstrumentor().instrument()
self.assertTrue(hasattr(asyncpg, "_opentelemetry_tracer"))
AsyncPGInstrumentor().uninstrument()
self.assertFalse(hasattr(asyncpg, "_opentelemetry_tracer"))
def test_duplicated_instrumentation(self):
AsyncPGInstrumentor().instrument()
AsyncPGInstrumentor().instrument()

View File

@ -228,7 +228,11 @@ class DatabaseApiIntegration:
}
self._name = name
self._version = version
self._tracer_provider = tracer_provider
self._tracer = get_tracer(
self._name,
instrumenting_library_version=self._version,
tracer_provider=tracer_provider,
)
self.capture_parameters = capture_parameters
self.database_system = database_system
self.connection_props = {}
@ -236,13 +240,6 @@ class DatabaseApiIntegration:
self.name = ""
self.database = ""
def get_tracer(self):
return get_tracer(
self._name,
instrumenting_library_version=self._version,
tracer_provider=self._tracer_provider,
)
def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
@ -370,7 +367,7 @@ class CursorTracer:
else self._db_api_integration.name
)
with self._db_api_integration.get_tracer().start_as_current_span(
with self._db_api_integration._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
self._populate_span(span, cursor, *args)

View File

@ -18,6 +18,7 @@ from unittest import mock
from opentelemetry import trace as trace_api
from opentelemetry.instrumentation import dbapi
from opentelemetry.sdk import resources
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
@ -41,7 +42,7 @@ class TestDBApiIntegration(TestBase):
"user": "user",
}
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", connection_attributes
"testname", "testcomponent", connection_attributes
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, connection_props
@ -73,7 +74,7 @@ class TestDBApiIntegration(TestBase):
def test_span_name(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", {}
"testname", "testcomponent", {}
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
@ -106,7 +107,7 @@ class TestDBApiIntegration(TestBase):
"user": "user",
}
db_integration = dbapi.DatabaseApiIntegration(
self.tracer,
"testname",
"testcomponent",
connection_attributes,
capture_parameters=True,
@ -155,12 +156,10 @@ class TestDBApiIntegration(TestBase):
"host": "server_host",
"user": "user",
}
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
db_integration = dbapi.DatabaseApiIntegration(
mock_tracer, "testcomponent", connection_attributes
"testname", "testcomponent", connection_attributes
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, connection_props
@ -192,9 +191,30 @@ class TestDBApiIntegration(TestBase):
self.assertIs(span.status.status_code, trace_api.StatusCode.ERROR)
self.assertEqual(span.status.description, "Exception: Test Exception")
def test_custom_tracer_provider_dbapi(self):
resource = resources.Resource.create({"db-resource-key": "value"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent", tracer_provider=tracer_provider
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
)
cursor = mock_connection.cursor()
with self.assertRaises(Exception):
cursor.execute("Test query", throw_exception=True)
spans_list = exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.resource.attributes["db-resource-key"], "value")
self.assertIs(span.status.status_code, trace_api.StatusCode.ERROR)
def test_executemany(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent"
"testname", "testcomponent"
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}
@ -210,7 +230,7 @@ class TestDBApiIntegration(TestBase):
def test_callproc(self):
db_integration = dbapi.DatabaseApiIntegration(
self.tracer, "testcomponent"
"testname", "testcomponent"
)
mock_connection = db_integration.wrapped_connection(
mock_connect, {}, {}

View File

@ -84,6 +84,7 @@ from opentelemetry.instrumentation.django.environment_variables import (
from opentelemetry.instrumentation.django.middleware import _DjangoMiddleware
from opentelemetry.instrumentation.django.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer
_logger = getLogger(__name__)
@ -105,6 +106,13 @@ class DjangoInstrumentor(BaseInstrumentor):
if environ.get(OTEL_PYTHON_DJANGO_INSTRUMENT) == "False":
return
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(
__name__, __version__, tracer_provider=tracer_provider,
)
_DjangoMiddleware._tracer = tracer
_DjangoMiddleware._otel_request_hook = kwargs.pop("request_hook", None)
_DjangoMiddleware._otel_response_hook = kwargs.pop(
"response_hook", None

View File

@ -31,7 +31,7 @@ from opentelemetry.instrumentation.wsgi import (
)
from opentelemetry.propagate import extract
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, SpanKind, get_tracer, use_span
from opentelemetry.trace import Span, SpanKind, use_span
from opentelemetry.util.http import get_excluded_urls, get_traced_request_attrs
try:
@ -82,6 +82,7 @@ class _DjangoMiddleware(MiddlewareMixin):
_traced_request_attrs = get_traced_request_attrs("DJANGO")
_excluded_urls = get_excluded_urls("DJANGO")
_tracer = None
_otel_request_hook: Callable[[Span, HttpRequest], None] = None
_otel_response_hook: Callable[
@ -125,9 +126,7 @@ class _DjangoMiddleware(MiddlewareMixin):
token = attach(extract(request_meta, getter=wsgi_getter))
tracer = get_tracer(__name__, __version__)
span = tracer.start_span(
span = self._tracer.start_span(
self._get_span_name(request),
kind=SpanKind.SERVER,
start_time=request_meta.get(

View File

@ -30,6 +30,7 @@ from opentelemetry.instrumentation.propagators import (
TraceResponsePropagator,
set_global_response_propagator,
)
from opentelemetry.sdk import resources
from opentelemetry.sdk.trace import Span
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
@ -354,3 +355,37 @@ class TestMiddleware(TestBase, WsgiTestBase):
),
)
self.memory_exporter.clear()
class TestMiddlewareWithTracerProvider(TestBase, WsgiTestBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
setup_test_environment()
resource = resources.Resource.create(
{"resource-key": "resource-value"}
)
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.exporter = exporter
_django_instrumentor.instrument(tracer_provider=tracer_provider)
def tearDown(self):
super().tearDown()
teardown_test_environment()
_django_instrumentor.uninstrument()
def test_tracer_provider_traced(self):
Client().post("/traced/")
spans = self.exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(
span.resource.attributes["resource-key"], "resource-value"
)

View File

@ -146,10 +146,12 @@ class _InstrumentedFalconAPI(falcon.API):
def __init__(self, *args, **kwargs):
# inject trace middleware
middlewares = kwargs.pop("middleware", [])
tracer_provider = kwargs.pop("tracer_provider", None)
if not isinstance(middlewares, (list, tuple)):
middlewares = [middlewares]
self._tracer = trace.get_tracer(__name__, __version__)
self._tracer = trace.get_tracer(__name__, __version__, tracer_provider)
trace_middleware = _TraceMiddleware(
self._tracer,
kwargs.pop("traced_request_attributes", None),

View File

@ -22,6 +22,7 @@ from opentelemetry.instrumentation.propagators import (
get_global_response_propagator,
set_global_response_propagator,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import StatusCode, format_span_id, format_trace_id
@ -239,6 +240,36 @@ class TestFalconInstrumentation(TestFalconBase):
self.assertFalse(mock_span.set_status.called)
class TestFalconInstrumentationWithTracerProvider(TestBase):
def setUp(self):
super().setUp()
resource = Resource.create({"resource-key": "resource-value"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.exporter = exporter
FalconInstrumentor().instrument(tracer_provider=tracer_provider)
self.app = make_app()
def client(self):
return testing.TestClient(self.app)
def tearDown(self):
super().tearDown()
with self.disable_logging():
FalconInstrumentor().uninstrument()
def test_traced_request(self):
self.client().simulate_request(method="GET", path="/hello")
spans = self.exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(
span.resource.attributes["resource-key"], "resource-value"
)
self.exporter.clear()
class TestFalconInstrumentationHooks(TestFalconBase):
# pylint: disable=no-self-use
def request_hook(self, span, req):

View File

@ -32,7 +32,7 @@ class FastAPIInstrumentor(BaseInstrumentor):
_original_fastapi = None
@staticmethod
def instrument_app(app: fastapi.FastAPI):
def instrument_app(app: fastapi.FastAPI, tracer_provider=None):
"""Instrument an uninstrumented FastAPI application.
"""
if not getattr(app, "is_instrumented_by_opentelemetry", False):
@ -40,11 +40,13 @@ class FastAPIInstrumentor(BaseInstrumentor):
OpenTelemetryMiddleware,
excluded_urls=_excluded_urls,
span_details_callback=_get_route_details,
tracer_provider=tracer_provider,
)
app.is_instrumented_by_opentelemetry = True
def _instrument(self, **kwargs):
self._original_fastapi = fastapi.FastAPI
_InstrumentedFastAPI._tracer_provider = kwargs.get("tracer_provider")
fastapi.FastAPI = _InstrumentedFastAPI
def _uninstrument(self, **kwargs):
@ -52,12 +54,15 @@ class FastAPIInstrumentor(BaseInstrumentor):
class _InstrumentedFastAPI(fastapi.FastAPI):
_tracer_provider = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=_excluded_urls,
span_details_callback=_get_route_details,
tracer_provider=_InstrumentedFastAPI._tracer_provider,
)

View File

@ -19,6 +19,7 @@ import fastapi
from fastapi.testclient import TestClient
import opentelemetry.instrumentation.fastapi as otel_fastapi
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.util.http import get_excluded_urls
@ -115,9 +116,22 @@ class TestAutoInstrumentation(TestFastAPIManualInstrumentation):
def _create_app(self):
# instrumentation is handled by the instrument call
self._instrumentor.instrument()
resource = Resource.create({"key1": "value1", "key2": "value2"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.memory_exporter = exporter
self._instrumentor.instrument(tracer_provider=tracer_provider)
return self._create_fastapi_app()
def test_request(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertEqual(span.resource.attributes["key1"], "value1")
self.assertEqual(span.resource.attributes["key2"], "value2")
def tearDown(self):
self._instrumentor.uninstrument()
super().tearDown()

View File

@ -120,7 +120,7 @@ def _rewrapped_app(wsgi_app):
return _wrapped_app
def _wrapped_before_request(name_callback):
def _wrapped_before_request(name_callback, tracer):
def _before_request():
if _excluded_urls.url_disabled(flask.request.url):
return
@ -131,8 +131,6 @@ def _wrapped_before_request(name_callback):
extract(flask_request_environ, getter=otel_wsgi.wsgi_getter)
)
tracer = trace.get_tracer(__name__, __version__)
span = tracer.start_span(
span_name,
kind=trace.SpanKind.SERVER,
@ -184,6 +182,7 @@ def _teardown_request(exc):
class _InstrumentedFlask(flask.Flask):
name_callback = get_default_span_name
_tracer_provider = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -191,8 +190,12 @@ class _InstrumentedFlask(flask.Flask):
self._original_wsgi_ = self.wsgi_app
self.wsgi_app = _rewrapped_app(self.wsgi_app)
tracer = trace.get_tracer(
__name__, __version__, _InstrumentedFlask._tracer_provider
)
_before_request = _wrapped_before_request(
_InstrumentedFlask.name_callback
_InstrumentedFlask.name_callback, tracer,
)
self._before_request = _before_request
self.before_request(_before_request)
@ -209,12 +212,14 @@ class FlaskInstrumentor(BaseInstrumentor):
def _instrument(self, **kwargs):
self._original_flask = flask.Flask
name_callback = kwargs.get("name_callback")
tracer_provider = kwargs.get("tracer_provider")
if callable(name_callback):
_InstrumentedFlask.name_callback = name_callback
_InstrumentedFlask._tracer_provider = tracer_provider
flask.Flask = _InstrumentedFlask
def instrument_app(
self, app, name_callback=get_default_span_name
self, app, name_callback=get_default_span_name, tracer_provider=None
): # pylint: disable=no-self-use
if not hasattr(app, "_is_instrumented"):
app._is_instrumented = False
@ -223,7 +228,9 @@ class FlaskInstrumentor(BaseInstrumentor):
app._original_wsgi_app = app.wsgi_app
app.wsgi_app = _rewrapped_app(app.wsgi_app)
_before_request = _wrapped_before_request(name_callback)
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
_before_request = _wrapped_before_request(name_callback, tracer)
app._before_request = _before_request
app.before_request(_before_request)
app.teardown_request(_teardown_request)

View File

@ -23,6 +23,7 @@ from opentelemetry.instrumentation.propagators import (
get_global_response_propagator,
set_global_response_propagator,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.test.wsgitestutil import WsgiTestBase
@ -277,3 +278,69 @@ class TestProgrammaticCustomSpanNameCallbackWithoutApp(
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
self.assertEqual(span_list[0].name, "instrument-without-app")
class TestProgrammaticCustomTracerProvider(
InstrumentationTest, TestBase, WsgiTestBase
):
def setUp(self):
super().setUp()
resource = Resource.create({"service.name": "flask-api"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.memory_exporter = exporter
self.app = Flask(__name__)
FlaskInstrumentor().instrument_app(
self.app, tracer_provider=tracer_provider
)
self._common_initialization()
def tearDown(self):
super().tearDown()
with self.disable_logging():
FlaskInstrumentor().uninstrument_app(self.app)
def test_custom_span_name(self):
self.client.get("/hello/123")
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
self.assertEqual(
span_list[0].resource.attributes["service.name"], "flask-api"
)
class TestProgrammaticCustomTracerProviderWithoutApp(
InstrumentationTest, TestBase, WsgiTestBase
):
def setUp(self):
super().setUp()
resource = Resource.create({"service.name": "flask-api-no-app"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.memory_exporter = exporter
FlaskInstrumentor().instrument(tracer_provider=tracer_provider)
# pylint: disable=import-outside-toplevel,reimported,redefined-outer-name
from flask import Flask
self.app = Flask(__name__)
self._common_initialization()
def tearDown(self):
super().tearDown()
with self.disable_logging():
FlaskInstrumentor().uninstrument()
def test_custom_span_name(self):
self.client.get("/hello/123")
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
self.assertEqual(
span_list[0].resource.attributes["service.name"],
"flask-api-no-app",
)

View File

@ -147,13 +147,18 @@ class GrpcInstrumentorServer(BaseInstrumentor):
def _instrument(self, **kwargs):
self._original_func = grpc.server
tracer_provider = kwargs.get("tracer_provider")
def server(*args, **kwargs):
if "interceptors" in kwargs:
# add our interceptor as the first
kwargs["interceptors"].insert(0, server_interceptor())
kwargs["interceptors"].insert(
0, server_interceptor(tracer_provider=tracer_provider)
)
else:
kwargs["interceptors"] = [server_interceptor()]
kwargs["interceptors"] = [
server_interceptor(tracer_provider=tracer_provider)
]
return self._original_func(*args, **kwargs)
grpc.server = server

View File

@ -77,22 +77,24 @@ class MySQLInstrumentor(BaseInstrumentor):
dbapi.unwrap_connect(mysql.connector, "connect")
# pylint:disable=no-self-use
def instrument_connection(self, connection):
def instrument_connection(self, connection, tracer_provider=None):
"""Enable instrumentation in a MySQL connection.
Args:
connection: The connection to instrument.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
Returns:
An instrumented connection.
"""
tracer = get_tracer(__name__, __version__)
return dbapi.instrument_connection(
tracer,
__name__,
connection,
self._DATABASE_SYSTEM,
self._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)
def uninstrument_connection(self, connection):

View File

@ -86,11 +86,15 @@ class Psycopg2Instrumentor(BaseInstrumentor):
dbapi.unwrap_connect(psycopg2, "connect")
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
def instrument_connection(self, connection): # pylint: disable=no-self-use
def instrument_connection(
self, connection, tracer_provider=None
): # pylint: disable=no-self-use
setattr(
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
)
connection.cursor_factory = _new_cursor_factory()
connection.cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
return connection
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@ -146,13 +150,14 @@ class CursorTracer(dbapi.CursorTracer):
return statement
def _new_cursor_factory(db_api=None, base_factory=None):
def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
if not db_api:
db_api = DatabaseApiIntegration(
__name__,
Psycopg2Instrumentor._DATABASE_SYSTEM,
connection_attributes=Psycopg2Instrumentor._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)
base_factory = base_factory or pg_cursor

View File

@ -78,11 +78,13 @@ class PyMySQLInstrumentor(BaseInstrumentor):
dbapi.unwrap_connect(pymysql, "connect")
# pylint:disable=no-self-use
def instrument_connection(self, connection):
def instrument_connection(self, connection, tracer_provider=None):
"""Enable instrumentation in a PyMySQL connection.
Args:
connection: The connection to instrument.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
Returns:
An instrumented connection.
@ -94,6 +96,7 @@ class PyMySQLInstrumentor(BaseInstrumentor):
self._DATABASE_SYSTEM,
self._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)
def uninstrument_connection(self, connection):

View File

@ -111,13 +111,13 @@ class RedisInstrumentor(BaseInstrumentor):
"""
def _instrument(self, **kwargs):
tracer_provider = kwargs.get(
"tracer_provider", trace.get_tracer_provider()
)
tracer_provider = kwargs.get("tracer_provider")
setattr(
redis,
"_opentelemetry_tracer",
tracer_provider.get_tracer(_DEFAULT_SERVICE, __version__),
trace.get_tracer(
__name__, __version__, tracer_provider=tracer_provider,
),
)
if redis.VERSION < (3, 0, 0):

View File

@ -56,7 +56,7 @@ _SUPPRESS_HTTP_INSTRUMENTATION_KEY = "suppress_http_instrumentation"
# pylint: disable=unused-argument
# pylint: disable=R0915
def _instrument(tracer_provider=None, span_callback=None, name_callback=None):
def _instrument(tracer, span_callback=None, name_callback=None):
"""Enables tracing of all requests calls that go through
:code:`requests.session.Session.request` (this includes
:code:`requests.get`, etc.)."""
@ -126,9 +126,9 @@ def _instrument(tracer_provider=None, span_callback=None, name_callback=None):
labels[SpanAttributes.HTTP_METHOD] = method
labels[SpanAttributes.HTTP_URL] = url
with get_tracer(
__name__, __version__, tracer_provider
).start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
with tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT
) as span:
exception = None
if span.is_recording():
span.set_attribute(SpanAttributes.HTTP_METHOD, method)
@ -224,8 +224,10 @@ class RequestsInstrumentor(BaseInstrumentor):
outgoing HTTP request based on the method and url.
Optional: Defaults to get_default_span_name.
"""
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)
_instrument(
tracer_provider=kwargs.get("tracer_provider"),
tracer,
span_callback=kwargs.get("span_callback"),
name_callback=kwargs.get("name_callback"),
)

View File

@ -38,10 +38,10 @@ def _normalize_vendor(vendor):
def _get_tracer(engine, tracer_provider=None):
if tracer_provider is None:
tracer_provider = trace.get_tracer_provider()
return tracer_provider.get_tracer(
_normalize_vendor(engine.name), __version__
return trace.get_tracer(
_normalize_vendor(engine.name),
__version__,
tracer_provider=tracer_provider,
)

View File

@ -74,22 +74,25 @@ class SQLite3Instrumentor(BaseInstrumentor):
dbapi.unwrap_connect(sqlite3, "connect")
# pylint:disable=no-self-use
def instrument_connection(self, connection):
def instrument_connection(self, connection, tracer_provider=None):
"""Enable instrumentation in a SQLite connection.
Args:
connection: The connection to instrument.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
Returns:
An instrumented connection.
"""
tracer = get_tracer(__name__, __version__)
return dbapi.instrument_connection(
tracer,
__name__,
connection,
self._DATABASE_SYSTEM,
self._CONNECTION_ATTRIBUTES,
version=__version__,
tracer_provider=tracer_provider,
)
def uninstrument_connection(self, connection):

View File

@ -32,7 +32,7 @@ class StarletteInstrumentor(BaseInstrumentor):
_original_starlette = None
@staticmethod
def instrument_app(app: applications.Starlette):
def instrument_app(app: applications.Starlette, tracer_provider=None):
"""Instrument an uninstrumented Starlette application.
"""
if not getattr(app, "is_instrumented_by_opentelemetry", False):
@ -40,11 +40,13 @@ class StarletteInstrumentor(BaseInstrumentor):
OpenTelemetryMiddleware,
excluded_urls=_excluded_urls,
span_details_callback=_get_route_details,
tracer_provider=tracer_provider,
)
app.is_instrumented_by_opentelemetry = True
def _instrument(self, **kwargs):
self._original_starlette = applications.Starlette
_InstrumentedStarlette._tracer_provider = kwargs.get("tracer_provider")
applications.Starlette = _InstrumentedStarlette
def _uninstrument(self, **kwargs):
@ -52,12 +54,15 @@ class StarletteInstrumentor(BaseInstrumentor):
class _InstrumentedStarlette(applications.Starlette):
_tracer_provider = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=_excluded_urls,
span_details_callback=_get_route_details,
tracer_provider=_InstrumentedStarlette._tracer_provider,
)

View File

@ -21,6 +21,7 @@ from starlette.routing import Route
from starlette.testclient import TestClient
import opentelemetry.instrumentation.starlette as otel_starlette
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.util.http import get_excluded_urls
@ -109,13 +110,27 @@ class TestAutoInstrumentation(TestStarletteManualInstrumentation):
def _create_app(self):
# instrumentation is handled by the instrument call
self._instrumentor.instrument()
resource = Resource.create({"key1": "value1", "key2": "value2"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.memory_exporter = exporter
self._instrumentor.instrument(tracer_provider=tracer_provider)
return self._create_starlette_app()
def tearDown(self):
self._instrumentor.uninstrument()
super().tearDown()
def test_request(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertEqual(span.resource.attributes["key1"], "value1")
self.assertEqual(span.resource.attributes["key2"], "value2")
class TestAutoInstrumentationLogic(unittest.TestCase):
def test_instrumentation(self):

View File

@ -75,9 +75,10 @@ class URLLibInstrumentor(BaseInstrumentor):
outgoing HTTP request based on the method and url.
Optional: Defaults to get_default_span_name.
"""
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)
_instrument(
tracer_provider=kwargs.get("tracer_provider"),
tracer,
span_callback=kwargs.get("span_callback"),
name_callback=kwargs.get("name_callback"),
)
@ -97,7 +98,7 @@ def get_default_span_name(method):
return "HTTP {}".format(method).strip()
def _instrument(tracer_provider=None, span_callback=None, name_callback=None):
def _instrument(tracer, span_callback=None, name_callback=None):
"""Enables tracing of all requests calls that go through
:code:`urllib.Client._make_request`"""
@ -143,9 +144,9 @@ def _instrument(tracer_provider=None, span_callback=None, name_callback=None):
SpanAttributes.HTTP_URL: url,
}
with get_tracer(
__name__, __version__, tracer_provider
).start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
with tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT
) as span:
exception = None
if span.is_recording():
span.set_attribute(SpanAttributes.HTTP_METHOD, method)

View File

@ -86,8 +86,10 @@ class URLLib3Instrumentor(BaseInstrumentor):
``url_filter``: A callback to process the requested URL prior
to adding it as a span attribute.
"""
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)
_instrument(
tracer_provider=kwargs.get("tracer_provider"),
tracer,
span_name_or_callback=kwargs.get("span_name"),
url_filter=kwargs.get("url_filter"),
)
@ -97,7 +99,7 @@ class URLLib3Instrumentor(BaseInstrumentor):
def _instrument(
tracer_provider: TracerProvider = None,
tracer,
span_name_or_callback: _SpanNameT = None,
url_filter: _UrlFilterT = None,
):
@ -115,9 +117,7 @@ def _instrument(
SpanAttributes.HTTP_URL: url,
}
with get_tracer(
__name__, __version__, tracer_provider
).start_as_current_span(
with tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
inject(headers)

View File

@ -194,11 +194,15 @@ class OpenTelemetryMiddleware:
response_hook: Optional callback which is called with the server span,
WSGI environ, status_code and response_headers for every
incoming request.
tracer_provider: Optional tracer provider to use. If omitted the current
globally configured one is used.
"""
def __init__(self, wsgi, request_hook=None, response_hook=None):
def __init__(
self, wsgi, request_hook=None, response_hook=None, tracer_provider=None
):
self.wsgi = wsgi
self.tracer = trace.get_tracer(__name__, __version__)
self.tracer = trace.get_tracer(__name__, __version__, tracer_provider)
self.request_hook = request_hook
self.response_hook = response_hook

View File

@ -20,7 +20,9 @@ from urllib.parse import urlsplit
import opentelemetry.instrumentation.wsgi as otel_wsgi
from opentelemetry import trace as trace_api
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.test.wsgitestutil import WsgiTestBase
from opentelemetry.trace import StatusCode
@ -363,5 +365,41 @@ class TestWsgiAttributes(unittest.TestCase):
self.span.set_attribute.assert_has_calls(expected, any_order=True)
class TestWsgiMiddlewareWithTracerProvider(WsgiTestBase):
def validate_response(
self,
response,
exporter,
error=None,
span_name="HTTP GET",
http_method="GET",
):
while True:
try:
value = next(response)
self.assertEqual(value, b"*")
except StopIteration:
break
span_list = exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
self.assertEqual(span_list[0].name, span_name)
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
self.assertEqual(
span_list[0].resource.attributes["service-key"], "service-value"
)
def test_basic_wsgi_call(self):
resource = Resource.create({"service-key": "service-value"})
result = TestBase.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
app = otel_wsgi.OpenTelemetryMiddleware(
simple_wsgi, tracer_provider=tracer_provider
)
response = app(self.environ, self.start_response)
self.validate_response(response, exporter)
if __name__ == "__main__":
unittest.main()