Files
2020-04-08 10:39:44 -07:00

459 lines
16 KiB
Python

# stdlib
import contextlib
import logging
import unittest
from threading import Event
# 3p
from cassandra.cluster import Cluster, ResultSet
from cassandra.query import BatchStatement, SimpleStatement
# project
from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY
from ddtrace.contrib.cassandra.patch import patch, unpatch
from ddtrace.contrib.cassandra.session import get_traced_cassandra, SERVICE
from ddtrace.ext import net, cassandra as cassx, errors
from ddtrace import config, Pin
# testing
from tests.contrib.config import CASSANDRA_CONFIG
from tests.opentracer.utils import init_tracer
from tests.test_tracer import get_dummy_tracer
# Our tests often fail because the Cassandra connection times out during keyspace drop. Slowness in keyspace drop
# is known and is due to the 'auto_snapshot' configuration. In our test env we should disable it, but the official
# cassandra image that we are using only allows us to configure a few settings:
# https://github.com/docker-library/cassandra/blob/4474c6c5cc2a81ee57c5615aae00555fca7e26a6/3.11/docker-entrypoint.sh#L51
# So for now we just increase the timeout; if this is not enough we may want to extend the official image with our
# own custom image.
CONNECTION_TIMEOUT_SECS = 20  # override the default value of 5
# set the threshold of the 'cassandra' logger to INFO for this test module
logging.getLogger('cassandra').setLevel(logging.INFO)
def setUpModule():
    """Create the ``test`` keyspace, its two tables and the three seed rows used by this module.

    The cluster opened here is only needed to build the fixtures, so it is
    always shut down before returning, even when one of the statements fails.

    Raises:
        unittest.SkipTest: if ``Cluster`` is falsy.
    """
    # skip all the modules if the Cluster is not available
    if not Cluster:
        raise unittest.SkipTest('cassandra.cluster.Cluster is not available.')

    # statements run, in order, right after the keyspace has been dropped
    fixture_statements = (
        "CREATE KEYSPACE if not exists test WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor': 1};",
        'CREATE TABLE if not exists test.person (name text PRIMARY KEY, age int, description text)',
        'CREATE TABLE if not exists test.person_write (name text PRIMARY KEY, age int, description text)',
        "INSERT INTO test.person (name, age, description) VALUES ('Cassandra', 100, 'A cruel mistress')",
        "INSERT INTO test.person (name, age, description) VALUES ('Athena', 100, 'Whose shield is thunder')",
        "INSERT INTO test.person (name, age, description) VALUES ('Calypso', 100, 'Softly-braided nymph')",
    )

    # create the KEYSPACE for this test module
    cluster = Cluster(port=CASSANDRA_CONFIG['port'], connect_timeout=CONNECTION_TIMEOUT_SECS)
    try:
        session = cluster.connect()
        # start from a clean slate; the drop gets its own explicit timeout
        session.execute('DROP KEYSPACE IF EXISTS test', timeout=10)
        for statement in fixture_statements:
            session.execute(statement)
    finally:
        # close the session and connections opened for fixture creation so they are not leaked
        cluster.shutdown()
def tearDownModule():
    """Drop the tables and the ``test`` keyspace created by ``setUpModule``.

    The cluster used for the cleanup is shut down afterwards, whether or not
    the drop statements succeed.
    """
    # destroy the KEYSPACE
    cluster = Cluster(port=CASSANDRA_CONFIG['port'], connect_timeout=CONNECTION_TIMEOUT_SECS)
    try:
        session = cluster.connect()
        session.execute('DROP TABLE IF EXISTS test.person')
        session.execute('DROP TABLE IF EXISTS test.person_write')
        session.execute('DROP KEYSPACE IF EXISTS test', timeout=10)
    finally:
        # close the session and connections opened for the cleanup so they are not leaked
        cluster.shutdown()
class CassandraBase(object):
    """
    Shared test cases for the Cassandra integration.

    Needs a running Cassandra. Concrete test classes provide
    ``_traced_session`` and may override ``TEST_SERVICE`` to state which
    service name the generated spans are expected to carry.
    """
    TEST_QUERY = "SELECT * from test.person WHERE name = 'Cassandra'"
    TEST_QUERY_PAGINATED = 'SELECT * from test.person'
    TEST_KEYSPACE = 'test'
    TEST_PORT = CASSANDRA_CONFIG['port']
    TEST_SERVICE = 'test-cassandra'

    def _traced_session(self):
        """Return a ``(session, tracer)`` pair; every test in this class unpacks it that way.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # failing loudly beats returning None, which only shows up later as an unpacking TypeError
        raise NotImplementedError('subclasses must implement _traced_session()')

    @contextlib.contextmanager
    def override_config(self, integration, values):
        """
        Temporarily override an integration configuration value
        >>> with self.override_config('flask', dict(service_name='test-service')):
        ...    # Your test
        """
        options = getattr(config, integration)

        # remember the current value of every key we are about to override
        original = dict(
            (key, options.get(key))
            for key in values.keys()
        )

        options.update(values)
        try:
            yield
        finally:
            options.update(original)

    def setUp(self):
        """Open an untraced cluster and session against the test Cassandra."""
        self.cluster = Cluster(port=CASSANDRA_CONFIG['port'])
        self.session = self.cluster.connect()

    def _assert_result_correct(self, result):
        """Check that ``result`` holds exactly the 'Cassandra' seed row."""
        assert len(result.current_rows) == 1
        for r in result:
            assert r.name == 'Cassandra'
            assert r.age == 100
            assert r.description == 'A cruel mistress'

    def _test_query_base(self, execute_fn):
        """Run ``TEST_QUERY`` through ``execute_fn`` and validate the result and the single span produced.

        ``execute_fn`` is called as ``execute_fn(session, query)`` and must
        return the result set to validate.
        """
        session, tracer = self._traced_session()
        writer = tracer.writer
        result = execute_fn(session, self.TEST_QUERY)
        self._assert_result_correct(result)

        spans = writer.pop()
        assert spans, spans

        # exactly one span, for the actual query
        assert len(spans) == 1

        query = spans[0]
        assert query.service == self.TEST_SERVICE
        assert query.resource == self.TEST_QUERY
        assert query.span_type == 'cassandra'

        assert query.get_tag(cassx.KEYSPACE) == self.TEST_KEYSPACE
        assert query.get_metric(net.TARGET_PORT) == self.TEST_PORT
        assert query.get_metric(cassx.ROW_COUNT) == 1
        assert query.get_tag(cassx.PAGE_NUMBER) is None
        assert query.get_tag(cassx.PAGINATED) == 'False'
        assert query.get_tag(net.TARGET_HOST) == '127.0.0.1'

        # confirm no analytics sample rate set by default
        assert query.get_metric(ANALYTICS_SAMPLE_RATE_KEY) is None

    def test_query(self):
        """A synchronous ``session.execute`` produces the expected span."""
        def execute_fn(session, query):
            return session.execute(query)
        self._test_query_base(execute_fn)

    def test_query_analytics_with_rate(self):
        """With analytics enabled and an explicit rate, the span carries that rate."""
        with self.override_config(
            'cassandra',
            dict(analytics_enabled=True, analytics_sample_rate=0.5)
        ):
            session, tracer = self._traced_session()
            session.execute(self.TEST_QUERY)

        writer = tracer.writer
        spans = writer.pop()
        assert spans, spans
        # exactly one span, for the actual query
        assert len(spans) == 1
        query = spans[0]
        # confirm the configured analytics sample rate is reported
        assert query.get_metric(ANALYTICS_SAMPLE_RATE_KEY) == 0.5

    def test_query_analytics_without_rate(self):
        """With analytics enabled and no explicit rate, the span reports a rate of 1.0."""
        with self.override_config(
            'cassandra',
            dict(analytics_enabled=True)
        ):
            session, tracer = self._traced_session()
            session.execute(self.TEST_QUERY)

        writer = tracer.writer
        spans = writer.pop()
        assert spans, spans
        # exactly one span, for the actual query
        assert len(spans) == 1
        query = spans[0]
        # confirm the analytics sample rate is 1.0 when only analytics_enabled is set
        assert query.get_metric(ANALYTICS_SAMPLE_RATE_KEY) == 1.0

    def test_query_ot(self):
        """Ensure that cassandra works with the opentracer."""
        def execute_fn(session, query):
            return session.execute(query)

        session, tracer = self._traced_session()
        ot_tracer = init_tracer('cass_svc', tracer)
        writer = tracer.writer

        with ot_tracer.start_active_span('cass_op'):
            result = execute_fn(session, self.TEST_QUERY)
            self._assert_result_correct(result)

        spans = writer.pop()
        assert spans, spans

        # one span for the opentracer operation and one for the actual query
        assert len(spans) == 2
        ot_span, dd_span = spans

        # confirm parenting
        assert ot_span.parent_id is None
        assert dd_span.parent_id == ot_span.span_id

        assert ot_span.name == 'cass_op'
        assert ot_span.service == 'cass_svc'

        assert dd_span.service == self.TEST_SERVICE
        assert dd_span.resource == self.TEST_QUERY
        assert dd_span.span_type == 'cassandra'

        assert dd_span.get_tag(cassx.KEYSPACE) == self.TEST_KEYSPACE
        assert dd_span.get_metric(net.TARGET_PORT) == self.TEST_PORT
        assert dd_span.get_metric(cassx.ROW_COUNT) == 1
        assert dd_span.get_tag(cassx.PAGE_NUMBER) is None
        assert dd_span.get_tag(cassx.PAGINATED) == 'False'
        assert dd_span.get_tag(net.TARGET_HOST) == '127.0.0.1'

    def test_query_async(self):
        """An ``execute_async`` call consumed through a callback produces the expected span."""
        def execute_fn(session, query):
            event = Event()
            result = []
            failures = []
            future = session.execute_async(query)

            def callback(results):
                result.append(ResultSet(future, results))
                event.set()

            def errback(exc):
                # without an errback a failed query never sets the event and the test hangs forever
                failures.append(exc)
                event.set()

            future.add_callback(callback)
            future.add_errback(errback)
            event.wait()
            if failures:
                raise failures[0]
            return result[0]
        self._test_query_base(execute_fn)

    def test_query_async_clearing_callbacks(self):
        """Clearing the future's callbacks before reading the result still yields the expected span."""
        def execute_fn(session, query):
            future = session.execute_async(query)
            future.clear_callbacks()
            return future.result()
        self._test_query_base(execute_fn)

    def test_span_is_removed_from_future(self):
        """Once the result is available, the future no longer has a ``_ddtrace_current_span``."""
        session, _ = self._traced_session()
        future = session.execute_async(self.TEST_QUERY)
        future.result()
        span = getattr(future, '_ddtrace_current_span', None)
        assert span is None

    def test_paginated_query(self):
        """Iterating a query with ``fetch_size=1`` produces one span per page request."""
        session, tracer = self._traced_session()
        writer = tracer.writer
        statement = SimpleStatement(self.TEST_QUERY_PAGINATED, fetch_size=1)
        result = session.execute(statement)
        # iterate over all pages
        results = list(result)
        assert len(results) == 3

        spans = writer.pop()
        assert spans, spans

        # There are 4 spans for 3 results since the driver makes a request with
        # no result to check that it has reached the last page
        assert len(spans) == 4

        for i, query in enumerate(spans):
            assert query.service == self.TEST_SERVICE
            assert query.resource == self.TEST_QUERY_PAGINATED
            assert query.span_type == 'cassandra'

            assert query.get_tag(cassx.KEYSPACE) == self.TEST_KEYSPACE
            assert query.get_metric(net.TARGET_PORT) == self.TEST_PORT
            if i == 3:
                # the trailing request that discovers there are no more pages
                assert query.get_metric(cassx.ROW_COUNT) == 0
            else:
                assert query.get_metric(cassx.ROW_COUNT) == 1
            assert query.get_tag(net.TARGET_HOST) == '127.0.0.1'
            assert query.get_tag(cassx.PAGINATED) == 'True'
            assert query.get_metric(cassx.PAGE_NUMBER) == i + 1

    def test_trace_with_service(self):
        """The span of a query is reported under ``TEST_SERVICE``."""
        session, tracer = self._traced_session()
        writer = tracer.writer
        session.execute(self.TEST_QUERY)
        spans = writer.pop()
        assert spans
        assert len(spans) == 1
        query = spans[0]
        assert query.service == self.TEST_SERVICE

    def test_trace_error(self):
        """A failing query yields a span flagged as an error with message and type tags."""
        session, tracer = self._traced_session()
        writer = tracer.writer

        try:
            session.execute('select * from test.i_dont_exist limit 1')
        except Exception:
            # any failure is acceptable here; what matters is the span checked below
            pass
        else:
            assert 0, 'querying a table that does not exist should have raised'

        spans = writer.pop()
        assert spans
        query = spans[0]
        assert query.error == 1
        for k in (errors.ERROR_MSG, errors.ERROR_TYPE):
            assert query.get_tag(k)

    def test_bound_statement(self):
        """Prepared and explicitly bound statements both use the query text as span resource."""
        session, tracer = self._traced_session()
        writer = tracer.writer

        query = 'INSERT INTO test.person_write (name, age, description) VALUES (?, ?, ?)'
        prepared = session.prepare(query)
        session.execute(prepared, ('matt', 34, 'can'))

        prepared = session.prepare(query)
        bound_stmt = prepared.bind(('leo', 16, 'fr'))
        session.execute(bound_stmt)

        spans = writer.pop()
        assert len(spans) == 2
        for s in spans:
            assert s.resource == query

    def test_batch_statement(self):
        """A batch of two simple statements yields one 'BatchStatement' span with a batch size of 2."""
        session, tracer = self._traced_session()
        writer = tracer.writer

        batch = BatchStatement()
        batch.add(
            SimpleStatement('INSERT INTO test.person_write (name, age, description) VALUES (%s, %s, %s)'),
            ('Joe', 1, 'a'),
        )
        batch.add(
            SimpleStatement('INSERT INTO test.person_write (name, age, description) VALUES (%s, %s, %s)'),
            ('Jane', 2, 'b'),
        )
        session.execute(batch)

        spans = writer.pop()
        assert len(spans) == 1
        s = spans[0]
        assert s.resource == 'BatchStatement'
        assert s.get_metric('cassandra.batch_size') == 2
        assert 'test.person' in s.get_tag('cassandra.query')

    def test_batched_bound_statement(self):
        """A batch holding only a bound statement yields one span with an empty 'cassandra.query' tag."""
        session, tracer = self._traced_session()
        writer = tracer.writer

        batch = BatchStatement()

        prepared_statement = session.prepare('INSERT INTO test.person_write (name, age, description) VALUES (?, ?, ?)')
        batch.add(
            prepared_statement.bind(('matt', 34, 'can'))
        )
        session.execute(batch)

        spans = writer.pop()
        assert len(spans) == 1
        s = spans[0]
        assert s.resource == 'BatchStatement'
        assert s.get_tag('cassandra.query') == ''
class TestCassPatchDefault(unittest.TestCase, CassandraBase):
    """Test Cassandra instrumentation with patching and default configuration"""

    TEST_SERVICE = SERVICE

    def setUp(self):
        """Open the base cluster, then patch the cassandra integration."""
        CassandraBase.setUp(self)
        patch()

    def tearDown(self):
        """Shut down the cluster used by the test and remove the patch."""
        try:
            # close the sessions and connections of this test so they do not pile up across the suite
            self.cluster.shutdown()
        finally:
            # always unpatch, even if the shutdown fails
            unpatch()

    def _traced_session(self):
        """Return a session on the ``setUp`` cluster together with the dummy tracer pinned onto that cluster."""
        tracer = get_dummy_tracer()
        # keep the pin already found on the cluster, only swapping in the dummy tracer
        Pin.get_from(self.cluster).clone(tracer=tracer).onto(self.cluster)
        return self.cluster.connect(self.TEST_KEYSPACE), tracer
class TestCassPatchAll(TestCassPatchDefault):
    """Test Cassandra instrumentation with patching and custom service on all clusters"""

    TEST_SERVICE = 'test-cassandra-patch-all'

    def setUp(self):
        """Open the base cluster, then patch the cassandra integration."""
        CassandraBase.setUp(self)
        patch()

    def tearDown(self):
        """Shut down the cluster currently held by the test and remove the patch."""
        try:
            # close the sessions and connections of this test so they do not pile up across the suite
            self.cluster.shutdown()
        finally:
            # always unpatch, even if the shutdown fails
            unpatch()

    def _traced_session(self):
        """Pin the ``Cluster`` class itself, then return a session on a new cluster plus the dummy tracer."""
        tracer = get_dummy_tracer()
        # pin the Cluster class; the cluster created below is expected to report TEST_SERVICE
        Pin(service=self.TEST_SERVICE, tracer=tracer).onto(Cluster)
        # the cluster opened in setUp is replaced below; shut it down first so it is not leaked
        self.cluster.shutdown()
        self.cluster = Cluster(port=CASSANDRA_CONFIG['port'])

        return self.cluster.connect(self.TEST_KEYSPACE), tracer
class TestCassPatchOne(TestCassPatchDefault):
    """Test Cassandra instrumentation with patching and custom service on one cluster"""

    TEST_SERVICE = 'test-cassandra-patch-one'

    def setUp(self):
        """Open the base cluster, then patch the cassandra integration."""
        CassandraBase.setUp(self)
        patch()

    def tearDown(self):
        """Shut down the cluster currently held by the test and remove the patch."""
        try:
            # close the sessions and connections of this test so they do not pile up across the suite
            self.cluster.shutdown()
        finally:
            # always unpatch, even if the shutdown fails
            unpatch()

    def _traced_session(self):
        """Pin a different service on the ``Cluster`` class and ``TEST_SERVICE`` on one new cluster instance."""
        tracer = get_dummy_tracer()
        # pin the global Cluster to test if they will conflict
        Pin(service='not-%s' % self.TEST_SERVICE).onto(Cluster)
        # the cluster opened in setUp is replaced below; shut it down first so it is not leaked
        self.cluster.shutdown()
        self.cluster = Cluster(port=CASSANDRA_CONFIG['port'])

        Pin(service=self.TEST_SERVICE, tracer=tracer).onto(self.cluster)
        return self.cluster.connect(self.TEST_KEYSPACE), tracer

    def _query_on_fresh_cluster(self):
        """Run ``TEST_QUERY`` on a newly created cluster and always shut that cluster down."""
        cluster = Cluster(port=CASSANDRA_CONFIG['port'])
        try:
            cluster.connect(self.TEST_KEYSPACE).execute(self.TEST_QUERY)
        finally:
            cluster.shutdown()

    def test_patch_unpatch(self):
        """Spans appear while patched, disappear after ``unpatch`` and come back after patching again."""
        # Test patch idempotence
        patch()
        patch()

        tracer = get_dummy_tracer()
        Pin.get_from(Cluster).clone(tracer=tracer).onto(Cluster)

        self._query_on_fresh_cluster()

        spans = tracer.writer.pop()
        assert spans, spans
        # patching twice must not produce duplicate spans
        assert len(spans) == 1

        # Test unpatch
        unpatch()

        self._query_on_fresh_cluster()

        spans = tracer.writer.pop()
        assert not spans, spans

        # Test patch again
        patch()
        Pin.get_from(Cluster).clone(tracer=tracer).onto(Cluster)

        self._query_on_fresh_cluster()

        spans = tracer.writer.pop()
        assert spans, spans
def test_backwards_compat_get_traced_cassandra():
    """The legacy ``get_traced_cassandra`` entry point still returns a factory whose sessions can run a query."""
    cluster_factory = get_traced_cassandra()
    cluster = cluster_factory(port=CASSANDRA_CONFIG['port'])
    try:
        session = cluster.connect()
        session.execute('drop table if exists test.person')
    finally:
        # NOTE(review): assumes the factory builds a cassandra Cluster exposing shutdown() -- confirm
        cluster.shutdown()