Fix async redis clients tracing (#1830)

* Fix async redis clients tracing

* Update changelog

* Add functional integration tests and fix linting issues

---------

Co-authored-by: Shalev Roda <65566801+shalevr@users.noreply.github.com>
This commit is contained in:
Vivanov98
2023-06-25 13:03:54 +01:00
committed by GitHub
parent e70437a36e
commit cd6b024327
4 changed files with 270 additions and 34 deletions

View File

@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830))
- Make Flask request span attributes available for `start_span`. - Make Flask request span attributes available for `start_span`.
([#1784](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1784)) ([#1784](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1784))
- Fix falcon instrumentation's usage of Span Status to only set the description if the status code is ERROR. - Fix falcon instrumentation's usage of Span Status to only set the description if the status code is ERROR.

View File

@ -136,6 +136,43 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value) span.set_attribute(key, value)
def _build_span_name(instance, cmd_args):
if len(cmd_args) > 0 and cmd_args[0]:
name = cmd_args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
return name
def _build_span_meta_data_for_pipeline(instance):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)
cmds = [
_format_command_args(c.args if hasattr(c, "args") else c[0])
for c in command_stack
]
resource = "\n".join(cmds)
span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""
return command_stack, resource, span_name
# pylint: disable=R0915
def _instrument( def _instrument(
tracer, tracer,
request_hook: _RequestHookT = None, request_hook: _RequestHookT = None,
@ -143,11 +180,8 @@ def _instrument(
): ):
def _traced_execute_command(func, instance, args, kwargs): def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args) query = _format_command_args(args)
name = _build_span_name(instance, args)
if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span( with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT name, kind=trace.SpanKind.CLIENT
) as span: ) as span:
@ -163,31 +197,11 @@ def _instrument(
return response return response
def _traced_execute_pipeline(func, instance, args, kwargs): def _traced_execute_pipeline(func, instance, args, kwargs):
try: (
command_stack = ( command_stack,
instance.command_stack resource,
if hasattr(instance, "command_stack") span_name,
else instance._command_stack ) = _build_span_meta_data_for_pipeline(instance)
)
cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0],
)
for c in command_stack
]
resource = "\n".join(cmds)
span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""
with tracer.start_as_current_span( with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT span_name, kind=trace.SpanKind.CLIENT
@ -232,32 +246,72 @@ def _instrument(
"ClusterPipeline.execute", "ClusterPipeline.execute",
_traced_execute_pipeline, _traced_execute_pipeline,
) )
async def _async_traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = _build_span_name(instance, args)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
if callable(request_hook):
request_hook(span, instance, args, kwargs)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response
async def _async_traced_execute_pipeline(func, instance, args, kwargs):
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance)
with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(command_stack)
)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response
if redis.VERSION >= _REDIS_ASYNCIO_VERSION: if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
wrap_function_wrapper( wrap_function_wrapper(
"redis.asyncio", "redis.asyncio",
f"{redis_class}.execute_command", f"{redis_class}.execute_command",
_traced_execute_command, _async_traced_execute_command,
) )
wrap_function_wrapper( wrap_function_wrapper(
"redis.asyncio.client", "redis.asyncio.client",
f"{pipeline_class}.execute", f"{pipeline_class}.execute",
_traced_execute_pipeline, _async_traced_execute_pipeline,
) )
wrap_function_wrapper( wrap_function_wrapper(
"redis.asyncio.client", "redis.asyncio.client",
f"{pipeline_class}.immediate_execute_command", f"{pipeline_class}.immediate_execute_command",
_traced_execute_command, _async_traced_execute_command,
) )
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
wrap_function_wrapper( wrap_function_wrapper(
"redis.asyncio.cluster", "redis.asyncio.cluster",
"RedisCluster.execute_command", "RedisCluster.execute_command",
_traced_execute_command, _async_traced_execute_command,
) )
wrap_function_wrapper( wrap_function_wrapper(
"redis.asyncio.cluster", "redis.asyncio.cluster",
"ClusterPipeline.execute", "ClusterPipeline.execute",
_traced_execute_pipeline, _async_traced_execute_pipeline,
) )

View File

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
from unittest import mock from unittest import mock
import redis import redis
import redis.asyncio
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor
@ -21,6 +23,24 @@ from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind from opentelemetry.trace import SpanKind
class AsyncMock:
"""A sufficient async mock implementation.
Python 3.7 doesn't have an inbuilt async mock class, so this is used.
"""
def __init__(self):
self.mock = mock.Mock()
async def __call__(self, *args, **kwargs):
future = asyncio.Future()
future.set_result("random")
return future
def __getattr__(self, item):
return AsyncMock()
class TestRedis(TestBase): class TestRedis(TestBase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -87,6 +107,35 @@ class TestRedis(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
def test_instrument_uninstrument_async_client_command(self):
redis_client = redis.asyncio.Redis()
with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.memory_exporter.clear()
# Test uninstrument
RedisInstrumentor().uninstrument()
with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
self.memory_exporter.clear()
# Test instrument again
RedisInstrumentor().instrument()
with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
def test_response_hook(self): def test_response_hook(self):
redis_client = redis.Redis() redis_client = redis.Redis()
connection = redis.connection.Connection() connection = redis.connection.Connection()

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
from time import time_ns
import redis import redis
import redis.asyncio import redis.asyncio
@ -318,6 +319,29 @@ class TestAsyncRedisInstrument(TestBase):
) )
self.assertEqual(span.attributes.get("db.redis.args_length"), 2) self.assertEqual(span.attributes.get("db.redis.args_length"), 2)
def test_execute_command_traced_full_time(self):
"""Command should be traced for coroutine execution time, not creation time."""
coro_created_time = None
finish_time = None
async def pipeline_simple():
nonlocal coro_created_time
nonlocal finish_time
# delay coroutine creation from coroutine execution
coro = self.redis_client.get("foo")
coro_created_time = time_ns()
await coro
finish_time = time_ns()
async_call(pipeline_simple())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertTrue(span.start_time > coro_created_time)
self.assertTrue(span.end_time < finish_time)
def test_pipeline_traced(self): def test_pipeline_traced(self):
async def pipeline_simple(): async def pipeline_simple():
async with self.redis_client.pipeline( async with self.redis_client.pipeline(
@ -340,6 +364,35 @@ class TestAsyncRedisInstrument(TestBase):
) )
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)
def test_pipeline_traced_full_time(self):
"""Command should be traced for coroutine execution time, not creation time."""
coro_created_time = None
finish_time = None
async def pipeline_simple():
async with self.redis_client.pipeline(
transaction=False
) as pipeline:
nonlocal coro_created_time
nonlocal finish_time
pipeline.set("blah", 32)
pipeline.rpush("foo", "éé")
pipeline.hgetall("xxx")
# delay coroutine creation from coroutine execution
coro = pipeline.execute()
coro_created_time = time_ns()
await coro
finish_time = time_ns()
async_call(pipeline_simple())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertTrue(span.start_time > coro_created_time)
self.assertTrue(span.end_time < finish_time)
def test_pipeline_immediate(self): def test_pipeline_immediate(self):
async def pipeline_immediate(): async def pipeline_immediate():
async with self.redis_client.pipeline() as pipeline: async with self.redis_client.pipeline() as pipeline:
@ -359,6 +412,33 @@ class TestAsyncRedisInstrument(TestBase):
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?" span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?"
) )
def test_pipeline_immediate_traced_full_time(self):
"""Command should be traced for coroutine execution time, not creation time."""
coro_created_time = None
finish_time = None
async def pipeline_simple():
async with self.redis_client.pipeline(
transaction=False
) as pipeline:
nonlocal coro_created_time
nonlocal finish_time
pipeline.set("a", 1)
# delay coroutine creation from coroutine execution
coro = pipeline.immediate_execute_command("SET", "b", 2)
coro_created_time = time_ns()
await coro
finish_time = time_ns()
async_call(pipeline_simple())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertTrue(span.start_time > coro_created_time)
self.assertTrue(span.end_time < finish_time)
def test_parent(self): def test_parent(self):
"""Ensure OpenTelemetry works with redis.""" """Ensure OpenTelemetry works with redis."""
ot_tracer = trace.get_tracer("redis_svc") ot_tracer = trace.get_tracer("redis_svc")
@ -408,6 +488,29 @@ class TestAsyncRedisClusterInstrument(TestBase):
) )
self.assertEqual(span.attributes.get("db.redis.args_length"), 2) self.assertEqual(span.attributes.get("db.redis.args_length"), 2)
def test_execute_command_traced_full_time(self):
"""Command should be traced for coroutine execution time, not creation time."""
coro_created_time = None
finish_time = None
async def pipeline_simple():
nonlocal coro_created_time
nonlocal finish_time
# delay coroutine creation from coroutine execution
coro = self.redis_client.get("foo")
coro_created_time = time_ns()
await coro
finish_time = time_ns()
async_call(pipeline_simple())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertTrue(span.start_time > coro_created_time)
self.assertTrue(span.end_time < finish_time)
def test_pipeline_traced(self): def test_pipeline_traced(self):
async def pipeline_simple(): async def pipeline_simple():
async with self.redis_client.pipeline( async with self.redis_client.pipeline(
@ -430,6 +533,35 @@ class TestAsyncRedisClusterInstrument(TestBase):
) )
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)
def test_pipeline_traced_full_time(self):
"""Command should be traced for coroutine execution time, not creation time."""
coro_created_time = None
finish_time = None
async def pipeline_simple():
async with self.redis_client.pipeline(
transaction=False
) as pipeline:
nonlocal coro_created_time
nonlocal finish_time
pipeline.set("blah", 32)
pipeline.rpush("foo", "éé")
pipeline.hgetall("xxx")
# delay coroutine creation from coroutine execution
coro = pipeline.execute()
coro_created_time = time_ns()
await coro
finish_time = time_ns()
async_call(pipeline_simple())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertTrue(span.start_time > coro_created_time)
self.assertTrue(span.end_time < finish_time)
def test_parent(self): def test_parent(self):
"""Ensure OpenTelemetry works with redis.""" """Ensure OpenTelemetry works with redis."""
ot_tracer = trace.get_tracer("redis_svc") ot_tracer = trace.get_tracer("redis_svc")