# Files
# 2020-04-08 10:39:44 -07:00
#
# 357 lines
# 12 KiB
# Python

import concurrent
import concurrent.futures
import time

from ddtrace.contrib.futures import patch, unpatch
from tests.opentracer.utils import init_tracer

from ...base import BaseTracerTestCase
class PropagationTestCase(BaseTracerTestCase):
    """Ensure context propagation works between threads.

    Covers the case where the ``futures`` library is used, or where the
    ``concurrent`` module is available (Python 3 only).

    Each test submits work to a ``concurrent.futures.ThreadPoolExecutor`` and
    then uses the ``BaseTracerTestCase`` assertion helpers to check how spans
    created in worker threads relate to the span active in the submitting
    thread.
    """

    def setUp(self):
        super(PropagationTestCase, self).setUp()

        # instrument ``concurrent``
        patch()

    def tearDown(self):
        # remove instrumentation
        unpatch()

        super(PropagationTestCase, self).tearDown()

    def test_propagation(self):
        # it must propagate the tracing context if available

        def fn():
            # an active context must be available
            # DEV: With `ContextManager` `.active()` will never be `None`
            self.assertIsNotNone(self.tracer.context_provider.active())
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn)
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                dict(name='executor.thread'),
            ),
        )

    def test_propagation_with_params(self):
        # instrumentation must proxy arguments if available

        def fn(value, key=None):
            # an active context must be available
            # DEV: With `ThreadLocalContext` `.active()` will never be `None`
            self.assertIsNotNone(self.tracer.context_provider.active())
            with self.tracer.trace('executor.thread'):
                return value, key

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn, 42, 'CheeseShop')
                    value, key = future.result()
                    # assert the right result
                    self.assertEqual(value, 42)
                    self.assertEqual(key, 'CheeseShop')

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                dict(name='executor.thread'),
            ),
        )

    def test_disabled_instrumentation(self):
        # it must not propagate if the module is disabled
        # DEV: tearDown() will call unpatch() a second time
        unpatch()

        def fn():
            # an active context must be available
            # DEV: With `ThreadLocalContext` `.active()` will never be `None`
            self.assertIsNotNone(self.tracer.context_provider.active())
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn)
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # we provide two different traces
        self.assert_span_count(2)

        # Retrieve the root spans (no parents)
        # DEV: Results are sorted based on root span start time
        traces = self.get_root_spans()
        self.assertEqual(len(traces), 2)

        traces[0].assert_structure(dict(name='main.thread'))
        traces[1].assert_structure(dict(name='executor.thread'))

    def test_double_instrumentation(self):
        # double instrumentation must not happen
        patch()

        def fn():
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn)
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                dict(name='executor.thread'),
            ),
        )

    def test_no_parent_span(self):
        # without an active span the worker span must become a root span

        def fn():
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                future = executor.submit(fn)
                result = future.result()
                # assert the right result
                self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(dict(name='executor.thread'))

    def test_multiple_futures(self):
        # every submitted future must be parented to the submitting span

        def fn():
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                    futures = [executor.submit(fn) for _ in range(4)]
                    for future in futures:
                        result = future.result()
                        # assert the right result
                        self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                dict(name='executor.thread'),
                dict(name='executor.thread'),
                dict(name='executor.thread'),
                dict(name='executor.thread'),
            ),
        )

    def test_multiple_futures_no_parent(self):
        # without an active span each future must produce its own trace

        def fn():
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                futures = [executor.submit(fn) for _ in range(4)]
                for future in futures:
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_span_count(4)
        traces = self.get_root_spans()
        self.assertEqual(len(traces), 4)
        for trace in traces:
            trace.assert_structure(dict(name='executor.thread'))

    def test_nested_futures(self):
        # context must propagate through an executor started inside a worker

        def fn2():
            with self.tracer.trace('nested.thread'):
                return 42

        def fn():
            with self.tracer.trace('executor.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn2)
                    result = future.result()
                    self.assertEqual(result, 42)
                    return result

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn)
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_span_count(3)
        self.assert_structure(
            dict(name='main.thread'),
            (
                (
                    dict(name='executor.thread'),
                    (
                        dict(name='nested.thread'),
                    ),
                ),
            ),
        )

    def test_multiple_nested_futures(self):
        # fan-out at two levels must keep the full parent/child hierarchy

        def fn2():
            with self.tracer.trace('nested.thread'):
                return 42

        def fn():
            with self.tracer.trace('executor.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    futures = [executor.submit(fn2) for _ in range(4)]
                    # check every nested result, not only the first one
                    for future in futures:
                        result = future.result()
                        self.assertEqual(result, 42)
                    return result

        with self.override_global_tracer():
            with self.tracer.trace('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    futures = [executor.submit(fn) for _ in range(4)]
                    for future in futures:
                        result = future.result()
                        # assert the right result
                        self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                (
                    dict(name='executor.thread'),
                    (
                        dict(name='nested.thread'),
                    ) * 4,
                ),
            ) * 4,
        )

    def test_multiple_nested_futures_no_parent(self):
        # without a root span each worker span becomes the root of its trace

        def fn2():
            with self.tracer.trace('nested.thread'):
                return 42

        def fn():
            with self.tracer.trace('executor.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    futures = [executor.submit(fn2) for _ in range(4)]
                    # check every nested result, not only the first one
                    for future in futures:
                        result = future.result()
                        self.assertEqual(result, 42)
                    return result

        with self.override_global_tracer():
            with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                futures = [executor.submit(fn) for _ in range(4)]
                for future in futures:
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        traces = self.get_root_spans()
        self.assertEqual(len(traces), 4)
        for trace in traces:
            trace.assert_structure(
                dict(name='executor.thread'),
                (
                    dict(name='nested.thread'),
                ) * 4,
            )

    def test_send_trace_when_finished(self):
        # it must send the trace only when all threads are finished

        def fn():
            with self.tracer.trace('executor.thread'):
                # wait before returning
                time.sleep(0.05)
                return 42

        # DEV: the executor is deliberately not used as a context manager:
        # its ``__exit__`` would block until ``fn`` completes, i.e. before
        # ``main.thread`` finishes, which would defeat this test. It is shut
        # down explicitly in ``finally`` so the worker thread is not leaked.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
        try:
            with self.override_global_tracer():
                with self.tracer.trace('main.thread'):
                    # don't wait for the execution
                    future = executor.submit(fn)
                    time.sleep(0.01)

            # assert main thread span is finished first
            self.assert_span_count(1)
            self.assert_structure(dict(name='main.thread'))

            # then wait for the second thread and send the trace
            result = future.result()
            self.assertEqual(result, 42)

            self.assert_span_count(2)
            self.assert_structure(
                dict(name='main.thread'),
                (
                    dict(name='executor.thread'),
                ),
            )
        finally:
            executor.shutdown(wait=True)

    def test_propagation_ot(self):
        """OpenTracing version of test_propagation."""
        # it must propagate the tracing context if available
        ot_tracer = init_tracer('my_svc', self.tracer)

        def fn():
            # an active context must be available
            self.assertIsNotNone(self.tracer.context_provider.active())
            with self.tracer.trace('executor.thread'):
                return 42

        with self.override_global_tracer():
            with ot_tracer.start_active_span('main.thread'):
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    future = executor.submit(fn)
                    result = future.result()
                    # assert the right result
                    self.assertEqual(result, 42)

        # the trace must be completed
        self.assert_structure(
            dict(name='main.thread'),
            (
                dict(name='executor.thread'),
            ),
        )