Files
2020-04-08 10:39:44 -07:00

220 lines
6.5 KiB
Python

import sys
# 3p
from ddtrace.vendor import wrapt
import pymemcache
from pymemcache.client.base import Client
from pymemcache.exceptions import (
MemcacheClientError,
MemcacheServerError,
MemcacheUnknownCommandError,
MemcacheUnknownError,
MemcacheIllegalInputError,
)
# project
from ...constants import ANALYTICS_SAMPLE_RATE_KEY
from ...compat import reraise
from ...ext import SpanTypes, net, memcached as memcachedx
from ...internal.logger import get_logger
from ...pin import Pin
from ...settings import config
log = get_logger(__name__)
# keep a reference to the original unpatched clients
_Client = Client
class WrappedClient(wrapt.ObjectProxy):
"""Wrapper providing patched methods of a pymemcache Client.
Relevant connection information is obtained during initialization and
attached to each span.
Keys are tagged in spans for methods that act upon a key.
"""
def __init__(self, *args, **kwargs):
c = _Client(*args, **kwargs)
super(WrappedClient, self).__init__(c)
# tags to apply to each span generated by this client
tags = _get_address_tags(*args, **kwargs)
parent_pin = Pin.get_from(pymemcache)
if parent_pin:
pin = parent_pin.clone(tags=tags)
else:
pin = Pin(tags=tags)
# attach the pin onto this instance
pin.onto(self)
def set(self, *args, **kwargs):
return self._traced_cmd('set', *args, **kwargs)
def set_many(self, *args, **kwargs):
return self._traced_cmd('set_many', *args, **kwargs)
def add(self, *args, **kwargs):
return self._traced_cmd('add', *args, **kwargs)
def replace(self, *args, **kwargs):
return self._traced_cmd('replace', *args, **kwargs)
def append(self, *args, **kwargs):
return self._traced_cmd('append', *args, **kwargs)
def prepend(self, *args, **kwargs):
return self._traced_cmd('prepend', *args, **kwargs)
def cas(self, *args, **kwargs):
return self._traced_cmd('cas', *args, **kwargs)
def get(self, *args, **kwargs):
return self._traced_cmd('get', *args, **kwargs)
def get_many(self, *args, **kwargs):
return self._traced_cmd('get_many', *args, **kwargs)
def gets(self, *args, **kwargs):
return self._traced_cmd('gets', *args, **kwargs)
def gets_many(self, *args, **kwargs):
return self._traced_cmd('gets_many', *args, **kwargs)
def delete(self, *args, **kwargs):
return self._traced_cmd('delete', *args, **kwargs)
def delete_many(self, *args, **kwargs):
return self._traced_cmd('delete_many', *args, **kwargs)
def incr(self, *args, **kwargs):
return self._traced_cmd('incr', *args, **kwargs)
def decr(self, *args, **kwargs):
return self._traced_cmd('decr', *args, **kwargs)
def touch(self, *args, **kwargs):
return self._traced_cmd('touch', *args, **kwargs)
def stats(self, *args, **kwargs):
return self._traced_cmd('stats', *args, **kwargs)
def version(self, *args, **kwargs):
return self._traced_cmd('version', *args, **kwargs)
def flush_all(self, *args, **kwargs):
return self._traced_cmd('flush_all', *args, **kwargs)
def quit(self, *args, **kwargs):
return self._traced_cmd('quit', *args, **kwargs)
def set_multi(self, *args, **kwargs):
"""set_multi is an alias for set_many"""
return self._traced_cmd('set_many', *args, **kwargs)
def get_multi(self, *args, **kwargs):
"""set_multi is an alias for set_many"""
return self._traced_cmd('get_many', *args, **kwargs)
def _traced_cmd(self, method_name, *args, **kwargs):
"""Run and trace the given command.
Any pymemcache exception is caught and span error information is
set. The exception is then reraised for the application to handle
appropriately.
Relevant tags are set in the span.
"""
method = getattr(self.__wrapped__, method_name)
p = Pin.get_from(self)
# if the pin does not exist or is not enabled, shortcut
if not p or not p.enabled():
return method(*args, **kwargs)
with p.tracer.trace(
memcachedx.CMD,
service=p.service,
resource=method_name,
span_type=SpanTypes.CACHE,
) as span:
# set analytics sample rate
span.set_tag(
ANALYTICS_SAMPLE_RATE_KEY,
config.pymemcache.get_analytics_sample_rate()
)
# try to set relevant tags, catch any exceptions so we don't mess
# with the application
try:
span.set_tags(p.tags)
vals = _get_query_string(args)
query = '{}{}{}'.format(method_name, ' ' if vals else '', vals)
span.set_tag(memcachedx.QUERY, query)
except Exception:
log.debug('Error setting relevant pymemcache tags')
try:
return method(*args, **kwargs)
except (
MemcacheClientError,
MemcacheServerError,
MemcacheUnknownCommandError,
MemcacheUnknownError,
MemcacheIllegalInputError,
):
(typ, val, tb) = sys.exc_info()
span.set_exc_info(typ, val, tb)
reraise(typ, val, tb)
def _get_address_tags(*args, **kwargs):
"""Attempt to get host and port from args passed to Client initializer."""
tags = {}
try:
if len(args):
host, port = args[0]
tags[net.TARGET_HOST] = host
tags[net.TARGET_PORT] = port
except Exception:
log.debug('Error collecting client address tags')
return tags
def _get_query_string(args):
"""Return the query values given the arguments to a pymemcache command.
If there are multiple query values, they are joined together
space-separated.
"""
keys = ''
# shortcut if no args
if not args:
return keys
# pull out the first arg which will contain any key
arg = args[0]
# if we get a dict, convert to list of keys
if type(arg) is dict:
arg = list(arg)
if type(arg) is str:
keys = arg
elif type(arg) is bytes:
keys = arg.decode()
elif type(arg) is list and len(arg):
if type(arg[0]) is str:
keys = ' '.join(arg)
elif type(arg[0]) is bytes:
keys = b' '.join(arg).decode()
return keys