Respect provided tracer provider when instrumenting SQLAlchemy (#728)

* respect provided tracer provider when instrumenting sqlalchemy

This change updates the SQLALchemyInstrumentor to respect the tracer
provider that is passed in through the kwargs when patching the
`create_engine` functionality provided by SQLAlchemy. Previously, it
would default to the global tracer provider.

* feedback: pass in tracer_provider directly rather than kwargs

* feedback: update changelog

* build: lint
This commit is contained in:
Jim Myers
2021-10-12 13:49:22 -04:00
committed by GitHub
parent 5105820fff
commit e8af7a3339
4 changed files with 66 additions and 22 deletions

View File

@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#713](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/713))
- `opentelemetry-sdk-extension-aws` Move AWS X-Ray Propagator into its own `opentelemetry-propagators-aws` package
([#720](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/720))
- `opentelemetry-instrumentation-sqlalchemy` Respect provided tracer provider when instrumenting SQLAlchemy
([#728](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/728))
### Changed

View File

@ -88,20 +88,23 @@ class SQLAlchemyInstrumentor(BaseInstrumentor):
Returns:
An instrumented engine if passed in as an argument, None otherwise.
"""
_w("sqlalchemy", "create_engine", _wrap_create_engine)
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
tracer_provider = kwargs.get("tracer_provider")
_w("sqlalchemy", "create_engine", _wrap_create_engine(tracer_provider))
_w(
"sqlalchemy.engine",
"create_engine",
_wrap_create_engine(tracer_provider),
)
if parse_version(sqlalchemy.__version__).release >= (1, 4):
_w(
"sqlalchemy.ext.asyncio",
"create_async_engine",
_wrap_create_async_engine,
_wrap_create_async_engine(tracer_provider),
)
if kwargs.get("engine") is not None:
return EngineTracer(
_get_tracer(
kwargs.get("engine"), kwargs.get("tracer_provider")
),
_get_tracer(kwargs.get("engine"), tracer_provider),
kwargs.get("engine"),
)
return None

View File

@ -42,24 +42,30 @@ def _get_tracer(engine, tracer_provider=None):
)
# pylint: disable=unused-argument
def _wrap_create_async_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine.sync_engine)
return engine
def _wrap_create_async_engine(tracer_provider=None):
# pylint: disable=unused-argument
def _wrap_create_async_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine)
return engine
return _wrap_create_async_engine_internal
# pylint: disable=unused-argument
def _wrap_create_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine)
return engine
def _wrap_create_engine(tracer_provider=None):
# pylint: disable=unused-argument
def _wrap_create_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine)
return engine
return _wrap_create_engine_internal
class EngineTracer:

View File

@ -20,6 +20,8 @@ from sqlalchemy import create_engine
from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.test.test_base import TestBase
@ -95,6 +97,37 @@ class TestSqlalchemyInstrumentation(TestBase):
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
def test_custom_tracer_provider(self):
provider = TracerProvider(
resource=Resource.create(
{
"service.name": "test",
"deployment.environment": "env",
"service.version": "1234",
},
),
)
provider.add_span_processor(
export.SimpleSpanProcessor(self.memory_exporter)
)
SQLAlchemyInstrumentor().instrument(tracer_provider=provider)
from sqlalchemy import create_engine # pylint: disable-all
engine = create_engine("sqlite:///:memory:")
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].resource.attributes["service.name"], "test")
self.assertEqual(
spans[0].resource.attributes["deployment.environment"], "env"
)
self.assertEqual(
spans[0].resource.attributes["service.version"], "1234"
)
@pytest.mark.skipif(
not sqlalchemy.__version__.startswith("1.4"),
reason="only run async tests for 1.4",