diff --git a/.travis.yml b/.travis.yml index df95e89..ea1c62e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -59,7 +59,20 @@ jobs: script: - cd django-${DJANGO_VERSION}/tests/ - ./runtests.py --parallel=1 --settings=test_mysql - + - name: flake8 + python: "3.8" + install: + - pip install -U pip + - pip install flake8 + script: + - flake8 --ignore=E203,E501,W503 --max-line-length=88 . + - name: black + python: "3.8" + install: + - pip install -U pip + - pip install black + script: + - black --check --exclude=doc/ . #- &django_3_0 # <<: *django_2_2 # name: "Django 3.0 test (Python 3.8)" diff --git a/MySQLdb/__init__.py b/MySQLdb/__init__.py index fbb5e41..824acac 100644 --- a/MySQLdb/__init__.py +++ b/MySQLdb/__init__.py @@ -13,28 +13,57 @@ For information on how MySQLdb handles type conversion, see the MySQLdb.converters module. """ -from MySQLdb.release import __version__, version_info, __author__ +try: + from MySQLdb.release import version_info + from . import _mysql -from . import _mysql + assert version_info == _mysql.version_info +except Exception: + raise ImportError( + "this is MySQLdb version {}, but _mysql is version {!r}\n_mysql: {!r}".format( + version_info, _mysql.version_info, _mysql.__file__ + ) + ) -if version_info != _mysql.version_info: - raise ImportError("this is MySQLdb version %s, but _mysql is version %r\n_mysql: %r" % - (version_info, _mysql.version_info, _mysql.__file__)) -threadsafety = 1 -apilevel = "2.0" -paramstyle = "format" - -from ._mysql import * +from ._mysql import ( + NotSupportedError, + OperationalError, + get_client_info, + ProgrammingError, + Error, + InterfaceError, + debug, + IntegrityError, + string_literal, + MySQLError, + DataError, + escape, + escape_string, + DatabaseError, + InternalError, + Warning, +) from MySQLdb.constants import FIELD_TYPE -from MySQLdb.times import Date, Time, Timestamp, \ - DateFromTicks, TimeFromTicks, TimestampFromTicks +from MySQLdb.times import ( + Date, + Time, + Timestamp, + DateFromTicks, + TimeFromTicks, + TimestampFromTicks, +) try: frozenset except NameError: from sets import ImmutableSet as frozenset +threadsafety = 1 +apilevel = "2.0" +paramstyle = "format" + + class DBAPISet(frozenset): """A special type of set for which A == x is true if A is a DBAPISet and x is a member of that set.""" @@ -45,49 +74,106 @@ class DBAPISet(frozenset): return other in self -STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, - FIELD_TYPE.VAR_STRING]) -BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, - FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) -NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, - FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, - FIELD_TYPE.TINY, FIELD_TYPE.YEAR, FIELD_TYPE.NEWDECIMAL]) -DATE = DBAPISet([FIELD_TYPE.DATE]) -TIME = DBAPISet([FIELD_TYPE.TIME]) +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet( + [ + FIELD_TYPE.BLOB, + FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, + FIELD_TYPE.TINY_BLOB, + ] +) +NUMBER = DBAPISet( + [ + FIELD_TYPE.DECIMAL, + FIELD_TYPE.DOUBLE, + FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, + FIELD_TYPE.LONG, + FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, + FIELD_TYPE.YEAR, + FIELD_TYPE.NEWDECIMAL, + ] +) +DATE = DBAPISet([FIELD_TYPE.DATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) -DATETIME = TIMESTAMP -ROWID = DBAPISet() +DATETIME = TIMESTAMP +ROWID = DBAPISet() + def test_DBAPISet_set_equality(): assert STRING == STRING + def test_DBAPISet_set_inequality(): assert STRING != NUMBER + def test_DBAPISet_set_equality_membership(): assert FIELD_TYPE.VAR_STRING == STRING + def test_DBAPISet_set_inequality_membership(): assert FIELD_TYPE.DATE != STRING + def Binary(x): return bytes(x) + def Connect(*args, **kwargs): """Factory function for connections.Connection.""" from MySQLdb.connections import Connection + return Connection(*args, **kwargs) + connect = Connection = Connect -__all__ = [ 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', - 'Date', 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', - 'TimestampFromTicks', 'DataError', 'DatabaseError', 'Error', - 'FIELD_TYPE', 'IntegrityError', 'InterfaceError', 'InternalError', - 'MySQLError', 'NUMBER', 'NotSupportedError', 'DBAPISet', - 'OperationalError', 'ProgrammingError', 'ROWID', 'STRING', 'TIME', - 'TIMESTAMP', 'Warning', 'apilevel', 'connect', 'connections', - 'constants', 'converters', 'cursors', 'debug', 'escape', - 'escape_string', 'get_client_info', - 'paramstyle', 'string_literal', 'threadsafety', 'version_info'] - +__all__ = [ + "BINARY", + "Binary", + "Connect", + "Connection", + "DATE", + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "DataError", + "DatabaseError", + "Error", + "FIELD_TYPE", + "IntegrityError", + "InterfaceError", + "InternalError", + "MySQLError", + "NUMBER", + "NotSupportedError", + "DBAPISet", + "OperationalError", + "ProgrammingError", + "ROWID", + "STRING", + "TIME", + "TIMESTAMP", + "Warning", + "apilevel", + "connect", + "connections", + "constants", + "converters", + "cursors", + "debug", + "escape", + "escape_string", + "get_client_info", + "paramstyle", + "string_literal", + "threadsafety", + "version_info", +] diff --git a/MySQLdb/_exceptions.py b/MySQLdb/_exceptions.py index 9cfff57..ba35dea 100644 --- a/MySQLdb/_exceptions.py +++ b/MySQLdb/_exceptions.py @@ -5,6 +5,7 @@ These classes are dictated by the DB API v2.0: https://www.python.org/dev/peps/pep-0249/ """ + class MySQLError(Exception): """Exception related to operation with MySQL.""" diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 1d67daa..8e226ff 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -8,9 +8,16 @@ import re from . import cursors, _mysql from ._exceptions import ( - Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, InternalError, - NotSupportedError, ProgrammingError, + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError, ) # Mapping from MySQL charset name to Python codec name @@ -24,6 +31,7 @@ _charset_to_encoding = { re_numeric_part = re.compile(r"^(\d+)") + def numeric_part(s): """Returns the leading numeric part of a string. @@ -136,13 +144,13 @@ class Connection(_mysql.connection): kwargs2 = kwargs.copy() - if 'database' in kwargs2: - kwargs2['db'] = kwargs2.pop('database') - if 'password' in kwargs2: - kwargs2['passwd'] = kwargs2.pop('password') + if "database" in kwargs2: + kwargs2["db"] = kwargs2.pop("database") + if "password" in kwargs2: + kwargs2["passwd"] = kwargs2.pop("password") - if 'conv' in kwargs: - conv = kwargs['conv'] + if "conv" in kwargs: + conv = kwargs["conv"] else: conv = conversions @@ -152,30 +160,31 @@ class Connection(_mysql.connection): conv2[k] = v[:] else: conv2[k] = v - kwargs2['conv'] = conv2 + kwargs2["conv"] = conv2 - cursorclass = kwargs2.pop('cursorclass', self.default_cursor) - charset = kwargs2.get('charset', '') - use_unicode = kwargs2.pop('use_unicode', True) - sql_mode = kwargs2.pop('sql_mode', '') - self._binary_prefix = kwargs2.pop('binary_prefix', False) + cursorclass = kwargs2.pop("cursorclass", self.default_cursor) + charset = kwargs2.get("charset", "") + use_unicode = kwargs2.pop("use_unicode", True) + sql_mode = kwargs2.pop("sql_mode", "") + self._binary_prefix = kwargs2.pop("binary_prefix", False) - client_flag = kwargs.get('client_flag', 0) - client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ]) + client_flag = kwargs.get("client_flag", 0) + client_version = tuple( + [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] + ) if client_version >= (4, 1): client_flag |= CLIENT.MULTI_STATEMENTS if client_version >= (5, 0): client_flag |= CLIENT.MULTI_RESULTS - kwargs2['client_flag'] = client_flag + kwargs2["client_flag"] = client_flag # PEP-249 requires autocommit to be initially off - autocommit = kwargs2.pop('autocommit', False) + autocommit = kwargs2.pop("autocommit", False) - super(Connection, self).__init__(*args, **kwargs2) + super().__init__(*args, **kwargs2) self.cursorclass = cursorclass - self.encoders = dict([ (k, v) for k, v in conv.items() - if type(k) is not int ]) + self.encoders = {k: v for k, v in conv.items() if type(k) is not int} # XXX THIS IS GARBAGE: While this is just a garbage and undocumented, # Django 1.11 depends on it. And they don't fix it because @@ -184,9 +193,11 @@ class Connection(_mysql.connection): # See PyMySQL/mysqlclient-python#306 self.encoders[bytes] = bytes - self._server_version = tuple([ numeric_part(n) for n in self.get_server_info().split('.')[:2] ]) + self._server_version = tuple( + [numeric_part(n) for n in self.get_server_info().split(".")[:2]] + ) - self.encoding = 'ascii' # overridden in set_character_set() + self.encoding = "ascii" # overridden in set_character_set() db = proxy(self) def unicode_literal(u, dummy=None): @@ -200,8 +211,15 @@ class Connection(_mysql.connection): self.set_sql_mode(sql_mode) if use_unicode: - for t in (FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING, FIELD_TYPE.VARCHAR, FIELD_TYPE.TINY_BLOB, - FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB): + for t in ( + FIELD_TYPE.STRING, + FIELD_TYPE.VAR_STRING, + FIELD_TYPE.VARCHAR, + FIELD_TYPE.TINY_BLOB, + FIELD_TYPE.MEDIUM_BLOB, + FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.BLOB, + ): self.converter[t] = _bytes_or_str # Unlike other string/blob types, JSON is always text. # MySQL may return JSON with charset==binary. @@ -244,11 +262,11 @@ class Connection(_mysql.connection): assert isinstance(bs, (bytes, bytearray)) x = self.string_literal(bs) # x is escaped and quoted bytes if self._binary_prefix: - return b'_binary' + x + return b"_binary" + x return x def _tuple_literal(self, t): - return b"(%s)" % (b','.join(map(self.literal, t))) + return b"(%s)" % (b",".join(map(self.literal, t))) def literal(self, o): """If o is a single object, returns an SQL literal as a string. @@ -280,7 +298,7 @@ class Connection(_mysql.connection): """ self.query(b"BEGIN") - if not hasattr(_mysql.connection, 'warning_count'): + if not hasattr(_mysql.connection, "warning_count"): def warning_count(self): """Return the number of warnings generated from the @@ -299,11 +317,11 @@ class Connection(_mysql.connection): py_charset = _charset_to_encoding.get(charset, charset) if self.character_set_name() != charset: try: - super(Connection, self).set_character_set(charset) + super().set_character_set(charset) except AttributeError: if self._server_version < (4, 1): raise NotSupportedError("server is too old to set charset") - self.query('SET NAMES %s' % charset) + self.query("SET NAMES %s" % charset) self.store_result() self.encoding = py_charset @@ -320,7 +338,8 @@ class Connection(_mysql.connection): sequence of tuples of (Level, Code, Message). This is only supported in MySQL-4.1 and up. If your server is an earlier version, an empty sequence is returned.""" - if self._server_version < (4,1): return () + if self._server_version < (4, 1): + return () self.query("SHOW WARNINGS") r = self.store_result() warnings = r.fetch_row(0) @@ -337,4 +356,5 @@ class Connection(_mysql.connection): ProgrammingError = ProgrammingError NotSupportedError = NotSupportedError + # vim: colorcolumn=100 diff --git a/MySQLdb/constants/CLIENT.py b/MySQLdb/constants/CLIENT.py index 6559917..35f578c 100644 --- a/MySQLdb/constants/CLIENT.py +++ b/MySQLdb/constants/CLIENT.py @@ -20,10 +20,8 @@ CHANGE_USER = 512 INTERACTIVE = 1024 SSL = 2048 IGNORE_SIGPIPE = 4096 -TRANSACTIONS = 8192 # mysql_com.h was WRONG prior to 3.23.35 +TRANSACTIONS = 8192 # mysql_com.h was WRONG prior to 3.23.35 RESERVED = 16384 SECURE_CONNECTION = 32768 MULTI_STATEMENTS = 65536 MULTI_RESULTS = 131072 - - diff --git a/MySQLdb/constants/CR.py b/MySQLdb/constants/CR.py index 753408e..9d33cf6 100644 --- a/MySQLdb/constants/CR.py +++ b/MySQLdb/constants/CR.py @@ -9,16 +9,18 @@ if __name__ == "__main__": """ Usage: python CR.py [/path/to/mysql/errmsg.h ...] >> CR.py """ - import fileinput, re + import fileinput + import re + data = {} error_last = None for line in fileinput.input(): - line = re.sub(r'/\*.*?\*/', '', line) - m = re.match(r'^\s*#define\s+CR_([A-Z0-9_]+)\s+(\d+)(\s.*|$)', line) + line = re.sub(r"/\*.*?\*/", "", line) + m = re.match(r"^\s*#define\s+CR_([A-Z0-9_]+)\s+(\d+)(\s.*|$)", line) if m: name = m.group(1) value = int(m.group(2)) - if name == 'ERROR_LAST': + if name == "ERROR_LAST": if error_last is None or error_last < value: error_last = value continue @@ -27,9 +29,9 @@ if __name__ == "__main__": data[value].add(name) for value, names in sorted(data.items()): for name in sorted(names): - print('%s = %s' % (name, value)) + print("{} = {}".format(name, value)) if error_last is not None: - print('ERROR_LAST = %s' % error_last) + print("ERROR_LAST = %s" % error_last) ERROR_FIRST = 2000 diff --git a/MySQLdb/constants/ER.py b/MySQLdb/constants/ER.py index 2e623b5..fcd5bf2 100644 --- a/MySQLdb/constants/ER.py +++ b/MySQLdb/constants/ER.py @@ -8,18 +8,20 @@ if __name__ == "__main__": """ Usage: python ER.py [/path/to/mysql/mysqld_error.h ...] >> ER.py """ - import fileinput, re + import fileinput + import re + data = {} error_last = None for line in fileinput.input(): - line = re.sub(r'/\*.*?\*/', '', line) - m = re.match(r'^\s*#define\s+((ER|WARN)_[A-Z0-9_]+)\s+(\d+)\s*', line) + line = re.sub(r"/\*.*?\*/", "", line) + m = re.match(r"^\s*#define\s+((ER|WARN)_[A-Z0-9_]+)\s+(\d+)\s*", line) if m: name = m.group(1) - if name.startswith('ER_'): + if name.startswith("ER_"): name = name[3:] value = int(m.group(3)) - if name == 'ERROR_LAST': + if name == "ERROR_LAST": if error_last is None or error_last < value: error_last = value continue @@ -28,9 +30,9 @@ if __name__ == "__main__": data[value].add(name) for value, names in sorted(data.items()): for name in sorted(names): - print('%s = %s' % (name, value)) + print("{} = {}".format(name, value)) if error_last is not None: - print('ERROR_LAST = %s' % error_last) + print("ERROR_LAST = %s" % error_last) ERROR_FIRST = 1000 diff --git a/MySQLdb/constants/__init__.py b/MySQLdb/constants/__init__.py index 3e774cd..0372265 100644 --- a/MySQLdb/constants/__init__.py +++ b/MySQLdb/constants/__init__.py @@ -1 +1 @@ -__all__ = ['CR', 'FIELD_TYPE','CLIENT','ER','FLAG'] +__all__ = ["CR", "FIELD_TYPE", "CLIENT", "ER", "FLAG"] diff --git a/MySQLdb/converters.py b/MySQLdb/converters.py index c460fbd..33f22f7 100644 --- a/MySQLdb/converters.py +++ b/MySQLdb/converters.py @@ -32,15 +32,24 @@ MySQL.connect(). """ from decimal import Decimal -from MySQLdb._mysql import string_literal, escape +from MySQLdb._mysql import string_literal from MySQLdb.constants import FIELD_TYPE, FLAG -from MySQLdb.times import * +from MySQLdb.times import ( + Date, + DateTimeType, + DateTime2literal, + DateTimeDeltaType, + DateTimeDelta2literal, + DateTime_or_None, + TimeDelta_or_None, + Date_or_None, +) from MySQLdb._exceptions import ProgrammingError -NoneType = type(None) - import array +NoneType = type(None) + try: ArrayType = array.ArrayType except AttributeError: @@ -48,28 +57,33 @@ except AttributeError: def Bool2Str(s, d): - return b'1' if s else b'0' + return b"1" if s else b"0" + def Set2Str(s, d): # Only support ascii string. Not tested. - return string_literal(','.join(s)) + return string_literal(",".join(s)) + def Thing2Str(s, d): """Convert something into a string via str().""" return str(s) + def Float2Str(o, d): s = repr(o) - if s in ('inf', 'nan'): + if s in ("inf", "nan"): raise ProgrammingError("%s can not be used with MySQL" % s) - if 'e' not in s: - s += 'e0' + if "e" not in s: + s += "e0" return s + def None2NULL(o, d): """Convert None to NULL.""" return b"NULL" + def Thing2Literal(o, d): """Convert something into a SQL string literal. If using MySQL-3.23 or newer, string_literal() is a method of the @@ -77,12 +91,15 @@ def Thing2Literal(o, d): that method when the connection is created.""" return string_literal(o) + def Decimal2Literal(o, d): - return format(o, 'f') + return format(o, "f") + def array2Str(o, d): return Thing2Literal(o.tostring(), d) + # bytes or str regarding to BINARY_FLAG. _bytes_or_str = ((FLAG.BINARY, bytes), (None, str)) @@ -97,7 +114,6 @@ conversions = { DateTimeDeltaType: DateTimeDelta2literal, set: Set2Str, Decimal: Decimal2Literal, - FIELD_TYPE.TINY: int, FIELD_TYPE.SHORT: int, FIELD_TYPE.LONG: int, @@ -112,7 +128,6 @@ conversions = { FIELD_TYPE.DATETIME: DateTime_or_None, FIELD_TYPE.TIME: TimeDelta_or_None, FIELD_TYPE.DATE: Date_or_None, - FIELD_TYPE.TINY_BLOB: bytes, FIELD_TYPE.MEDIUM_BLOB: bytes, FIELD_TYPE.LONG_BLOB: bytes, diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index d5ff03b..1d2ee46 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -12,13 +12,18 @@ from ._exceptions import ProgrammingError #: executemany only supports simple bulk insert. #: You can use it to load large dataset. RE_INSERT_VALUES = re.compile( - r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + - r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + - r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", - re.IGNORECASE | re.DOTALL) + "".join( + [ + r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)", + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))", + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + ] + ), + re.IGNORECASE | re.DOTALL, +) -class BaseCursor(object): +class BaseCursor: """A base for Cursor classes. Useful attributes: description @@ -39,12 +44,20 @@ class BaseCursor(object): #: #: Max size of allowed statement is max_allowed_packet - packet_header_size. #: Default value of max_allowed_packet is 1048576. - max_stmt_length = 64*1024 + max_stmt_length = 64 * 1024 from ._exceptions import ( - MySQLError, Warning, Error, InterfaceError, - DatabaseError, DataError, OperationalError, IntegrityError, - InternalError, ProgrammingError, NotSupportedError, + MySQLError, + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, ) connection = None @@ -98,8 +111,10 @@ class BaseCursor(object): if isinstance(args, (tuple, list)): ret = tuple(literal(ensure_bytes(arg)) for arg in args) elif isinstance(args, dict): - ret = {ensure_bytes(key): literal(ensure_bytes(val)) - for (key, val) in args.items()} + ret = { + ensure_bytes(key): literal(ensure_bytes(val)) + for (key, val) in args.items() + } else: # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error @@ -216,16 +231,23 @@ class BaseCursor(object): if m: q_prefix = m.group(1) % () q_values = m.group(2).rstrip() - q_postfix = m.group(3) or '' - assert q_values[0] == '(' and q_values[-1] == ')' - return self._do_execute_many(q_prefix, q_values, q_postfix, args, - self.max_stmt_length, - self._get_db().encoding) + q_postfix = m.group(3) or "" + assert q_values[0] == "(" and q_values[-1] == ")" + return self._do_execute_many( + q_prefix, + q_values, + q_postfix, + args, + self.max_stmt_length, + self._get_db().encoding, + ) self.rowcount = sum(self.execute(query, arg) for arg in args) return self.rowcount - def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + def _do_execute_many( + self, prefix, values, postfix, args, max_stmt_length, encoding + ): conn = self._get_db() escape = self._escape_args if isinstance(prefix, str): @@ -245,7 +267,7 @@ class BaseCursor(object): rows += self.execute(sql + postfix) sql = bytearray(prefix) else: - sql += b',' + sql += b"," sql += v rows += self.execute(sql + postfix) self.rowcount = rows @@ -283,15 +305,17 @@ class BaseCursor(object): if isinstance(procname, str): procname = procname.encode(db.encoding) if args: - fmt = b'@_' + procname + b'_%d=%s' - q = b'SET %s' % b','.join(fmt % (index, db.literal(arg)) - for index, arg in enumerate(args)) + fmt = b"@_" + procname + b"_%d=%s" + q = b"SET %s" % b",".join( + fmt % (index, db.literal(arg)) for index, arg in enumerate(args) + ) self._query(q) self.nextset() - q = b"CALL %s(%s)" % (procname, - b','.join([b'@_%s_%d' % (procname, i) - for i in range(len(args))])) + q = b"CALL %s(%s)" % ( + procname, + b",".join([b"@_%s_%d" % (procname, i) for i in range(len(args))]), + ) self._query(q) return args @@ -325,7 +349,7 @@ class BaseCursor(object): NotSupportedError = NotSupportedError -class CursorStoreResultMixIn(object): +class CursorStoreResultMixIn: """This is a MixIn class which causes the entire result set to be stored on the client side, i.e. it uses mysql_store_result(). If the result set can be very large, consider adding a LIMIT clause to your @@ -353,7 +377,7 @@ class CursorStoreResultMixIn(object): than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] + result = self._rows[self.rownumber : end] self.rownumber = min(end, len(self._rows)) return result @@ -361,13 +385,13 @@ class CursorStoreResultMixIn(object): """Fetchs all available rows from the cursor.""" self._check_executed() if self.rownumber: - result = self._rows[self.rownumber:] + result = self._rows[self.rownumber :] else: result = self._rows self.rownumber = len(self._rows) return result - def scroll(self, value, mode='relative'): + def scroll(self, value, mode="relative"): """Scroll the cursor in the result set to a new position according to mode. @@ -375,9 +399,9 @@ class CursorStoreResultMixIn(object): the current position in the result set, if set to 'absolute', value states an absolute target position.""" self._check_executed() - if mode == 'relative': + if mode == "relative": r = self.rownumber + value - elif mode == 'absolute': + elif mode == "absolute": r = value else: raise ProgrammingError("unknown scroll mode %s" % repr(mode)) @@ -387,11 +411,11 @@ class CursorStoreResultMixIn(object): def __iter__(self): self._check_executed() - result = self.rownumber and self._rows[self.rownumber:] or self._rows + result = self.rownumber and self._rows[self.rownumber :] or self._rows return iter(result) -class CursorUseResultMixIn(object): +class CursorUseResultMixIn: """This is a MixIn class which causes the result set to be stored in the server and sent row-by-row to client side, i.e. it uses @@ -438,39 +462,35 @@ class CursorUseResultMixIn(object): __next__ = next -class CursorTupleRowsMixIn(object): +class CursorTupleRowsMixIn: """This is a MixIn class that causes all rows to be returned as tuples, which is the standard form required by DB API.""" _fetch_type = 0 -class CursorDictRowsMixIn(object): +class CursorDictRowsMixIn: """This is a MixIn class that causes all rows to be returned as dictionaries. This is a non-standard feature.""" _fetch_type = 1 -class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn, - BaseCursor): +class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn, BaseCursor): """This is the standard Cursor class that returns rows as tuples and stores the result set in the client.""" -class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn, - BaseCursor): +class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn, BaseCursor): """This is a Cursor class that returns rows as dictionaries and stores the result set in the client.""" -class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn, - BaseCursor): +class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn, BaseCursor): """This is a Cursor class that returns rows as tuples and stores the result set in the server.""" -class SSDictCursor(CursorUseResultMixIn, CursorDictRowsMixIn, - BaseCursor): +class SSDictCursor(CursorUseResultMixIn, CursorDictRowsMixIn, BaseCursor): """This is a Cursor class that returns rows as dictionaries and stores the result set in the server.""" diff --git a/MySQLdb/times.py b/MySQLdb/times.py index d47c8fb..f0e9384 100644 --- a/MySQLdb/times.py +++ b/MySQLdb/times.py @@ -16,34 +16,50 @@ Timestamp = datetime DateTimeDeltaType = timedelta DateTimeType = datetime + def DateFromTicks(ticks): """Convert UNIX ticks into a date instance.""" return date(*localtime(ticks)[:3]) + def TimeFromTicks(ticks): """Convert UNIX ticks into a time instance.""" return time(*localtime(ticks)[3:6]) + def TimestampFromTicks(ticks): """Convert UNIX ticks into a datetime instance.""" return datetime(*localtime(ticks)[:6]) + format_TIME = format_DATE = str + def format_TIMEDELTA(v): seconds = int(v.seconds) % 60 minutes = int(v.seconds // 60) % 60 hours = int(v.seconds // 3600) % 24 - return '%d %d:%d:%d' % (v.days, hours, minutes, seconds) + return "%d %d:%d:%d" % (v.days, hours, minutes, seconds) + def format_TIMESTAMP(d): """ :type d: datetime.datetime """ if d.microsecond: - fmt = "{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}" + fmt = " ".join( + [ + "{0.year:04}-{0.month:02}-{0.day:02}", + "{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}", + ] + ) else: - fmt = "{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}" + fmt = " ".join( + [ + "{0.year:04}-{0.month:02}-{0.day:02}", + "{0.hour:02}:{0.minute:02}:{0.second:02}", + ] + ) return fmt.format(d) @@ -64,32 +80,32 @@ def DateTime_or_None(s): return None return datetime( - int(s[:4]), # year - int(s[5:7]), # month - int(s[8:10]), # day + int(s[:4]), # year + int(s[5:7]), # month + int(s[8:10]), # day int(s[11:13] or 0), # hour int(s[14:16] or 0), # minute int(s[17:19] or 0), # second - micros, # microsecond + micros, # microsecond ) except ValueError: return None + def TimeDelta_or_None(s): try: - h, m, s = s.split(':') - if '.' in s: - s, ms = s.split('.') - ms = ms.ljust(6, '0') + h, m, s = s.split(":") + if "." in s: + s, ms = s.split(".") + ms = ms.ljust(6, "0") else: ms = 0 - if h[0] == '-': + if h[0] == "-": negative = True else: negative = False h, m, s, ms = abs(int(h)), int(m), int(s), int(ms) - td = timedelta(hours=h, minutes=m, seconds=s, - microseconds=ms) + td = timedelta(hours=h, minutes=m, seconds=s, microseconds=ms) if negative: return -td else: @@ -98,34 +114,33 @@ def TimeDelta_or_None(s): # unpacking or int/float conversion failed return None + def Time_or_None(s): try: - h, m, s = s.split(':') - if '.' in s: - s, ms = s.split('.') - ms = ms.ljust(6, '0') + h, m, s = s.split(":") + if "." in s: + s, ms = s.split(".") + ms = ms.ljust(6, "0") else: ms = 0 h, m, s, ms = int(h), int(m), int(s), int(ms) - return time(hour=h, minute=m, second=s, - microsecond=ms) + return time(hour=h, minute=m, second=s, microsecond=ms) except ValueError: return None + def Date_or_None(s): try: - return date( - int(s[:4]), # year - int(s[5:7]), # month - int(s[8:10]), # day - ) + return date(int(s[:4]), int(s[5:7]), int(s[8:10]),) # year # month # day except ValueError: return None + def DateTime2literal(d, c): """Format a DateTime object as an ISO timestamp.""" return string_literal(format_TIMESTAMP(d)) + def DateTimeDelta2literal(d, c): """Format a DateTimeDelta object as a time.""" return string_literal(format_TIMEDELTA(d)) diff --git a/ci/test_mysql.py b/ci/test_mysql.py index d24f30f..88a747a 100644 --- a/ci/test_mysql.py +++ b/ci/test_mysql.py @@ -13,33 +13,27 @@ # file for each of the backends you test against. DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.mysql', - 'NAME': 'django_default', - 'USER': 'django', - 'HOST': '127.0.0.1', - 'PASSWORD': 'secret', - 'TEST': { - 'CHARSET': 'utf8mb4', - 'COLLATION': 'utf8mb4_general_ci', - }, + "default": { + "ENGINE": "django.db.backends.mysql", + "NAME": "django_default", + "USER": "django", + "HOST": "127.0.0.1", + "PASSWORD": "secret", + "TEST": {"CHARSET": "utf8mb4", "COLLATION": "utf8mb4_general_ci"}, + }, + "other": { + "ENGINE": "django.db.backends.mysql", + "NAME": "django_other", + "USER": "django", + "HOST": "127.0.0.1", + "PASSWORD": "secret", + "TEST": {"CHARSET": "utf8mb4", "COLLATION": "utf8mb4_general_ci"}, }, - 'other': { - 'ENGINE': 'django.db.backends.mysql', - 'NAME': 'django_other', - 'USER': 'django', - 'HOST': '127.0.0.1', - 'PASSWORD': 'secret', - 'TEST': { - 'CHARSET': 'utf8mb4', - 'COLLATION': 'utf8mb4_general_ci', - }, - } } SECRET_KEY = "django_tests_secret_key" # Use a fast hasher to speed up tests. PASSWORD_HASHERS = [ - 'django.contrib.auth.hashers.MD5PasswordHasher', + "django.contrib.auth.hashers.MD5PasswordHasher", ] diff --git a/doc/conf.py b/doc/conf.py index fc7c089..33f9781 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # MySQLdb documentation build configuration file, created by # sphinx-quickstart on Sun Oct 07 19:36:17 2012. @@ -11,46 +10,49 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +# skip flake8 and black for this file +# flake8: noqa +import sys +import os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('..')) +#sys.path.insert(0, os.path.abspath("..")) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +#needs_sphinx = "1.0" # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc'] +extensions = ["sphinx.ext.autodoc"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +#source_encoding = "utf-8-sig" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'MySQLdb' -copyright = u'2012, Andy Dustman' +project = "MySQLdb" +copyright = "2012, Andy Dustman" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.2' +version = "1.2" # The full version, including alpha/beta/rc tags. -release = '1.2.4b4' +release = "1.2.4b4" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -58,13 +60,13 @@ release = '1.2.4b4' # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +#today = "" # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +#today_fmt = "%B %d, %Y" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. #default_role = None @@ -81,7 +83,7 @@ exclude_patterns = ['_build'] #show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] @@ -91,7 +93,7 @@ pygments_style = 'sphinx' # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -120,11 +122,11 @@ html_theme = 'default' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +#html_last_updated_fmt = "%b %d, %Y" # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. @@ -158,33 +160,30 @@ html_static_path = ['_static'] # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +#html_use_opensearch = "" # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'MySQLdbdoc' +htmlhelp_basename = "MySQLdbdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'MySQLdb.tex', u'MySQLdb Documentation', - u'Andy Dustman', 'manual'), + ("index", "MySQLdb.tex", "MySQLdb Documentation", "Andy Dustman", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -212,10 +211,7 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'mysqldb', u'MySQLdb Documentation', - [u'Andy Dustman'], 1) -] +man_pages = [("index", "mysqldb", "MySQLdb Documentation", ["Andy Dustman"], 1)] # If true, show URL addresses after external links. #man_show_urls = False @@ -227,9 +223,15 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'MySQLdb', u'MySQLdb Documentation', - u'Andy Dustman', 'MySQLdb', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "MySQLdb", + "MySQLdb Documentation", + "Andy Dustman", + "MySQLdb", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. diff --git a/setup.py b/setup.py index a39e0d1..dfa661c 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import os -import io import setuptools @@ -10,14 +9,14 @@ if os.name == "posix": else: # assume windows from setup_windows import get_config -with io.open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() metadata, options = get_config() -metadata['ext_modules'] = [ - setuptools.Extension("MySQLdb._mysql", sources=['MySQLdb/_mysql.c'], **options) +metadata["ext_modules"] = [ + setuptools.Extension("MySQLdb._mysql", sources=["MySQLdb/_mysql.c"], **options) ] -metadata['long_description'] = readme -metadata['long_description_content_type'] = "text/markdown" -metadata['python_requires'] = '>=3.5' +metadata["long_description"] = readme +metadata["long_description_content_type"] = "text/markdown" +metadata["python_requires"] = ">=3.5" setuptools.setup(**metadata) diff --git a/setup_common.py b/setup_common.py index 2274e3a..28c5182 100644 --- a/setup_common.py +++ b/setup_common.py @@ -1,31 +1,37 @@ from configparser import ConfigParser as SafeConfigParser + def get_metadata_and_options(): config = SafeConfigParser() - config.read(['metadata.cfg', 'site.cfg']) + config.read(["metadata.cfg", "site.cfg"]) - metadata = dict(config.items('metadata')) - options = dict(config.items('options')) + metadata = dict(config.items("metadata")) + options = dict(config.items("options")) - metadata['py_modules'] = list(filter(None, metadata['py_modules'].split('\n'))) - metadata['classifiers'] = list(filter(None, metadata['classifiers'].split('\n'))) + metadata["py_modules"] = list(filter(None, metadata["py_modules"].split("\n"))) + metadata["classifiers"] = list(filter(None, metadata["classifiers"].split("\n"))) return metadata, options + def enabled(options, option): value = options[option] s = value.lower() - if s in ('yes','true','1','y'): + if s in ("yes", "true", "1", "y"): return True - elif s in ('no', 'false', '0', 'n'): + elif s in ("no", "false", "0", "n"): return False else: - raise ValueError("Unknown value %s for option %s" % (value, option)) + raise ValueError("Unknown value {} for option {}".format(value, option)) + def create_release_file(metadata): - with open("MySQLdb/release.py",'w') as rel: - rel.write(""" + with open("MySQLdb/release.py", "w") as rel: + rel.write( + """ __author__ = "%(author)s <%(author_email)s>" version_info = %(version_info)s __version__ = "%(version)s" -""" % metadata) +""" + % metadata + ) diff --git a/setup_posix.py b/setup_posix.py index db82b3c..5602be8 100644 --- a/setup_posix.py +++ b/setup_posix.py @@ -1,126 +1,147 @@ -import os, sys -from configparser import ConfigParser as SafeConfigParser +import os +import sys # This dequote() business is required for some older versions # of mysql_config + def dequote(s): if not s: - raise Exception("Wrong MySQL configuration: maybe https://bugs.mysql.com/bug.php?id=86971 ?") + raise Exception( + "Wrong MySQL configuration: maybe https://bugs.mysql.com/bug.php?id=86971 ?" + ) if s[0] in "\"'" and s[0] == s[-1]: s = s[1:-1] return s + _mysql_config_path = "mysql_config" + def mysql_config(what): from os import popen - f = popen("%s --%s" % (_mysql_config_path, what)) + f = popen("{} --{}".format(_mysql_config_path, what)) data = f.read().strip().split() ret = f.close() if ret: - if ret/256: + if ret / 256: data = [] - if ret/256 > 1: - raise EnvironmentError("%s not found" % (_mysql_config_path,)) + if ret / 256 > 1: + raise OSError("{} not found".format(_mysql_config_path)) return data + def get_config(): from setup_common import get_metadata_and_options, enabled, create_release_file + global _mysql_config_path metadata, options = get_metadata_and_options() - if 'mysql_config' in options: - _mysql_config_path = options['mysql_config'] + if "mysql_config" in options: + _mysql_config_path = options["mysql_config"] else: try: - mysql_config('version') - except EnvironmentError: + mysql_config("version") + except OSError: # try mariadb_config _mysql_config_path = "mariadb_config" try: - mysql_config('version') - except EnvironmentError: + mysql_config("version") + except OSError: _mysql_config_path = "mysql_config" extra_objects = [] - static = enabled(options, 'static') + static = enabled(options, "static") # allow a command-line option to override the base config file to permit # a static build to be created via requirements.txt # - if '--static' in sys.argv: + if "--static" in sys.argv: static = True - sys.argv.remove('--static') + sys.argv.remove("--static") libs = mysql_config("libs") - library_dirs = [dequote(i[2:]) for i in libs if i.startswith('-L')] - libraries = [dequote(i[2:]) for i in libs if i.startswith('-l')] - extra_link_args = [x for x in libs if not x.startswith(('-l', '-L'))] + library_dirs = [dequote(i[2:]) for i in libs if i.startswith("-L")] + libraries = [dequote(i[2:]) for i in libs if i.startswith("-l")] + extra_link_args = [x for x in libs if not x.startswith(("-l", "-L"))] - removable_compile_args = ('-I', '-L', '-l') - extra_compile_args = [i.replace("%", "%%") for i in mysql_config("cflags") - if i[:2] not in removable_compile_args] + removable_compile_args = ("-I", "-L", "-l") + extra_compile_args = [ + i.replace("%", "%%") + for i in mysql_config("cflags") + if i[:2] not in removable_compile_args + ] # Copy the arch flags for linking as well for i in range(len(extra_compile_args)): - if extra_compile_args[i] == '-arch': - extra_link_args += ['-arch', extra_compile_args[i + 1]] + if extra_compile_args[i] == "-arch": + extra_link_args += ["-arch", extra_compile_args[i + 1]] - include_dirs = [dequote(i[2:]) - for i in mysql_config('include') if i.startswith('-I')] + include_dirs = [ + dequote(i[2:]) for i in mysql_config("include") if i.startswith("-I") + ] if static: # properly handle mysql client libraries that are not called libmysqlclient client = None - CLIENT_LIST = ['mysqlclient', 'mysqlclient_r', 'mysqld', 'mariadb', - 'mariadbclient', 'perconaserverclient', 'perconaserverclient_r'] + CLIENT_LIST = [ + "mysqlclient", + "mysqlclient_r", + "mysqld", + "mariadb", + "mariadbclient", + "perconaserverclient", + "perconaserverclient_r", + ] for c in CLIENT_LIST: if c in libraries: client = c break - if client == 'mariadb': - client = 'mariadbclient' + if client == "mariadb": + client = "mariadbclient" if client is None: raise ValueError("Couldn't identify mysql client library") - extra_objects.append(os.path.join(library_dirs[0], 'lib%s.a' % client)) + extra_objects.append(os.path.join(library_dirs[0], "lib%s.a" % client)) if client in libraries: libraries.remove(client) else: # mysql_config may have "-lmysqlclient -lz -lssl -lcrypto", but zlib and # ssl is not used by _mysql. They are needed only for static build. - for L in ('crypto', 'ssl', 'z'): + for L in ("crypto", "ssl", "z"): if L in libraries: libraries.remove(L) name = "mysqlclient" - metadata['name'] = name + metadata["name"] = name define_macros = [ - ('version_info', metadata['version_info']), - ('__version__', metadata['version']), - ] + ("version_info", metadata["version_info"]), + ("__version__", metadata["version"]), + ] create_release_file(metadata) - del metadata['version_info'] + del metadata["version_info"] ext_options = dict( - library_dirs = library_dirs, - libraries = libraries, - extra_compile_args = extra_compile_args, - extra_link_args = extra_link_args, - include_dirs = include_dirs, - extra_objects = extra_objects, - define_macros = define_macros, + library_dirs=library_dirs, + libraries=libraries, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + include_dirs=include_dirs, + extra_objects=extra_objects, + define_macros=define_macros, ) # newer versions of gcc require libstdc++ if doing a static build if static: - ext_options['language'] = 'c++' + ext_options["language"] = "c++" return metadata, ext_options + if __name__ == "__main__": - sys.stderr.write("""You shouldn't be running this directly; it is used by setup.py.""") + sys.stderr.write( + """You shouldn't be running this directly; it is used by setup.py.""" + ) diff --git a/setup_windows.py b/setup_windows.py index cb2cbab..917eb49 100644 --- a/setup_windows.py +++ b/setup_windows.py @@ -1,9 +1,10 @@ -import os, sys +import os +import sys from distutils.msvccompiler import get_build_version def get_config(): - from setup_common import get_metadata_and_options, enabled, create_release_file + from setup_common import get_metadata_and_options, create_release_file metadata, options = get_metadata_and_options() @@ -16,37 +17,42 @@ def get_config(): vcversion = int(get_build_version()) if client == "mariadbclient": - library_dirs = [os.path.join(connector, 'lib', 'mariadb')] - libraries = ['kernel32', 'advapi32', 'wsock32', 'shlwapi', 'Ws2_32', client ] - include_dirs = [os.path.join(connector, 'include', 'mariadb')] + library_dirs = [os.path.join(connector, "lib", "mariadb")] + libraries = ["kernel32", "advapi32", "wsock32", "shlwapi", "Ws2_32", client] + include_dirs = [os.path.join(connector, "include", "mariadb")] else: - library_dirs = [os.path.join(connector, r'lib\vs%d' % vcversion), - os.path.join(connector, "lib")] - libraries = ['kernel32', 'advapi32', 'wsock32', client ] - include_dirs = [os.path.join(connector, r'include')] + library_dirs = [ + os.path.join(connector, r"lib\vs%d" % vcversion), + os.path.join(connector, "lib"), + ] + libraries = ["kernel32", "advapi32", "wsock32", client] + include_dirs = [os.path.join(connector, r"include")] - extra_compile_args = ['/Zl', '/D_CRT_SECURE_NO_WARNINGS' ] - extra_link_args = ['/MANIFEST'] + extra_compile_args = ["/Zl", "/D_CRT_SECURE_NO_WARNINGS"] + extra_link_args = ["/MANIFEST"] name = "mysqlclient" - metadata['name'] = name + metadata["name"] = name define_macros = [ - ('version_info', metadata['version_info']), - ('__version__', metadata['version']), - ] + ("version_info", metadata["version_info"]), + ("__version__", metadata["version"]), + ] create_release_file(metadata) - del metadata['version_info'] + del metadata["version_info"] ext_options = dict( - library_dirs = library_dirs, - libraries = libraries, - extra_compile_args = extra_compile_args, - extra_link_args = extra_link_args, - include_dirs = include_dirs, - extra_objects = extra_objects, - define_macros = define_macros, + library_dirs=library_dirs, + libraries=libraries, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + include_dirs=include_dirs, + extra_objects=extra_objects, + define_macros=define_macros, ) return metadata, ext_options + if __name__ == "__main__": - sys.stderr.write("""You shouldn't be running this directly; it is used by setup.py.""") + sys.stderr.write( + """You shouldn't be running this directly; it is used by setup.py.""" + ) diff --git a/tests/capabilities.py b/tests/capabilities.py index 15db533..cafe1e6 100644 --- a/tests/capabilities.py +++ b/tests/capabilities.py @@ -6,7 +6,6 @@ """ from time import time -import array import unittest from configdb import connection_factory @@ -16,35 +15,42 @@ class DatabaseTest(unittest.TestCase): db_module = None connect_args = () connect_kwargs = dict() - create_table_extra = '' + create_table_extra = "" rows = 10 debug = False def setUp(self): - import gc + db = connection_factory(**self.connect_kwargs) self.connection = db self.cursor = db.cursor() - self.BLOBUText = u''.join([chr(i) for i in range(16384)]) - self.BLOBBinary = self.db_module.Binary((u''.join([chr(i) for i in range(256)] * 16)).encode('latin1')) + self.BLOBUText = "".join([chr(i) for i in range(16384)]) + self.BLOBBinary = self.db_module.Binary( + ("".join([chr(i) for i in range(256)] * 16)).encode("latin1") + ) leak_test = True def tearDown(self): if self.leak_test: import gc + del self.cursor orphans = gc.collect() - self.failIf(orphans, "%d orphaned objects found after deleting cursor" % orphans) + self.failIf( + orphans, "%d orphaned objects found after deleting cursor" % orphans + ) del self.connection orphans = gc.collect() - self.failIf(orphans, "%d orphaned objects found after deleting connection" % orphans) + self.failIf( + orphans, "%d orphaned objects found after deleting connection" % orphans + ) def table_exists(self, name): try: - self.cursor.execute('select * from %s where 1=0' % name) - except: + self.cursor.execute("select * from %s where 1=0" % name) + except Exception: return False else: return True @@ -55,7 +61,7 @@ class DatabaseTest(unittest.TestCase): def new_table_name(self): i = id(self.cursor) while True: - name = self.quote_identifier('tb%08x' % i) + name = self.quote_identifier("tb%08x" % i) if not self.table_exists(name): return name i = i + 1 @@ -71,82 +77,95 @@ class DatabaseTest(unittest.TestCase): """ self.table = self.new_table_name() - self.cursor.execute('CREATE TABLE %s (%s) %s' % - (self.table, - ',\n'.join(columndefs), - self.create_table_extra)) + self.cursor.execute( + "CREATE TABLE %s (%s) %s" + % (self.table, ",\n".join(columndefs), self.create_table_extra) + ) def check_data_integrity(self, columndefs, generator): # insert self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) self.connection.commit() # verify - self.cursor.execute('select * from %s' % self.table) - l = self.cursor.fetchall() - self.assertEqual(len(l), self.rows) + self.cursor.execute("select * from %s" % self.table) + res = self.cursor.fetchall() + self.assertEqual(len(res), self.rows) try: for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) + self.assertEqual(res[i][j], generator(i, j)) finally: if not self.debug: - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_transactions(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*255 + if col == 0: + return row + else: + return ("%i" % (row % 10)) * 255 + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) # verify self.connection.commit() - self.cursor.execute('select * from %s' % self.table) - l = self.cursor.fetchall() - self.assertEqual(len(l), self.rows) + self.cursor.execute("select * from %s" % self.table) + res = self.cursor.fetchall() + self.assertEqual(len(res), self.rows) for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) - delete_statement = 'delete from %s where col1=%%s' % self.table + self.assertEqual(res[i][j], generator(i, j)) + delete_statement = "delete from %s where col1=%%s" % self.table self.cursor.execute(delete_statement, (0,)) - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) - l = self.cursor.fetchall() - self.assertFalse(l, "DELETE didn't work") + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) + res = self.cursor.fetchall() + self.assertFalse(res, "DELETE didn't work") self.connection.rollback() - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) - l = self.cursor.fetchall() - self.assertTrue(len(l) == 1, "ROLLBACK didn't work") - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) + res = self.cursor.fetchall() + self.assertTrue(len(res) == 1, "ROLLBACK didn't work") + self.cursor.execute("drop table %s" % (self.table)) def test_truncation(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*((255-self.rows//2)+row) + if col == 0: + return row + else: + return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row) + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) try: - self.cursor.execute(insert_statement, (0, '0'*256)) + self.cursor.execute(insert_statement, (0, "0" * 256)) except self.connection.DataError: pass else: - self.fail("Over-long column did not generate warnings/exception with single insert") + self.fail( + "Over-long column did not generate warnings/exception with single insert" # noqa: E501 + ) self.connection.rollback() @@ -154,143 +173,145 @@ class DatabaseTest(unittest.TestCase): for i in range(self.rows): data = [] for j in range(len(columndefs)): - data.append(generator(i,j)) - self.cursor.execute(insert_statement,tuple(data)) + data.append(generator(i, j)) + self.cursor.execute(insert_statement, tuple(data)) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with execute()") + self.fail( + "Over-long columns did not generate warnings/exception with execute()" # noqa: E501 + ) self.connection.rollback() try: - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + data = [ + [generator(i, j) for j in range(len(columndefs))] + for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with executemany()") + self.fail( + "Over-long columns did not generate warnings/exception with executemany()" # noqa: E501 + ) self.connection.rollback() - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_CHAR(self): # Character data - def generator(row,col): - return ('%i' % ((row+col) % 10)) * 255 - self.check_data_integrity( - ('col1 char(255)','col2 char(255)'), - generator) + def generator(row, col): + return ("%i" % ((row + col) % 10)) * 255 + + self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator) def test_INT(self): # Number data - def generator(row,col): - return row*row - self.check_data_integrity( - ('col1 INT',), - generator) + def generator(row, col): + return row * row + + self.check_data_integrity(("col1 INT",), generator) def test_DECIMAL(self): # DECIMAL from decimal import Decimal - def generator(row,col): - return Decimal("%d.%02d" % (row, col)) - self.check_data_integrity( - ('col1 DECIMAL(5,2)',), - generator) - val = Decimal('1.11111111111111119E-7') - self.cursor.execute('SELECT %s', (val,)) + def generator(row, col): + return Decimal("%d.%02d" % (row, col)) + + self.check_data_integrity(("col1 DECIMAL(5,2)",), generator) + + val = Decimal("1.11111111111111119E-7") + self.cursor.execute("SELECT %s", (val,)) result = self.cursor.fetchone()[0] self.assertEqual(result, val) self.assertIsInstance(result, Decimal) - self.cursor.execute('SELECT %s + %s', (Decimal('0.1'), Decimal('0.2'))) + self.cursor.execute("SELECT %s + %s", (Decimal("0.1"), Decimal("0.2"))) result = self.cursor.fetchone()[0] - self.assertEqual(result, Decimal('0.3')) + self.assertEqual(result, Decimal("0.3")) self.assertIsInstance(result, Decimal) def test_DATE(self): ticks = time() - def generator(row,col): - return self.db_module.DateFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATE',), - generator) + + def generator(row, col): + return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATE",), generator) def test_TIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIME',), - generator) + + def generator(row, col): + return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIME",), generator) def test_DATETIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATETIME',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATETIME",), generator) def test_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_fractional_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks( + ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0 + ) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_LONG(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBUText # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 LONG'), - generator) + return self.BLOBUText # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG"), generator) def test_TEXT(self): - def generator(row,col): - return self.BLOBUText # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col2 TEXT',), - generator) + def generator(row, col): + return self.BLOBUText # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col2 TEXT",), generator) def test_LONG_BYTE(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 LONG BYTE'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator) def test_BLOB(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 BLOB'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 BLOB"), generator) def test_DOUBLE(self): for val in (18014398509481982.0, 0.1): - self.cursor.execute('SELECT %s', (val,)); + self.cursor.execute("SELECT %s", (val,)) result = self.cursor.fetchone()[0] self.assertEqual(result, val) self.assertIsInstance(result, float) diff --git a/tests/configdb.py b/tests/configdb.py index 307cc3f..f3a56e2 100644 --- a/tests/configdb.py +++ b/tests/configdb.py @@ -3,12 +3,9 @@ from os import environ, path tests_path = path.dirname(__file__) -conf_file = environ.get('TESTDB', 'default.cnf') +conf_file = environ.get("TESTDB", "default.cnf") conf_path = path.join(tests_path, conf_file) -connect_kwargs = dict( - read_default_file = conf_path, - read_default_group = "MySQLdb-tests", -) +connect_kwargs = dict(read_default_file=conf_path, read_default_group="MySQLdb-tests",) def connection_kwargs(kwargs): @@ -19,6 +16,7 @@ def connection_kwargs(kwargs): def connection_factory(**kwargs): import MySQLdb + db_kwargs = connection_kwargs(kwargs) db = MySQLdb.connect(**db_kwargs) return db diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 79c188a..0ca8bce 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -''' Python DB API 2.0 driver compliance unit test suite. +""" Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -9,11 +9,11 @@ this is turning out to be a thoroughly unwholesome unit test." -- Ian Bicking -''' +""" -__rcs_id__ = '$Id$' -__version__ = '$Revision$'[11:-2] -__author__ = 'Stuart Bishop ' +__rcs_id__ = "$Id$" +__version__ = "$Revision$"[11:-2] +__author__ = "Stuart Bishop " import unittest import time @@ -64,8 +64,9 @@ import time # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # + class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. + """ Test a database self.driver for DB API 2.0 compatibility. This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this test case to ensure compiliance with the DB-API. It is @@ -84,45 +85,45 @@ class DatabaseAPI20Test(unittest.TestCase): Don't 'import DatabaseAPI20Test from dbapi20', or you will confuse the unit tester - just 'import dbapi20'. - ''' + """ # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = "dbapi20test_" # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + ddl1 = "create table %sbooze (name varchar(20))" % table_prefix + ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix + xddl1 = "drop table %sbooze" % table_prefix + xddl2 = "drop table %sbarflys" % table_prefix - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = "lower" # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def executeDDL1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - ''' self.drivers should override this method to perform required setup + """ self.drivers should override this method to perform required setup if any is necessary, such as creating the database. - ''' + """ pass def tearDown(self): - ''' self.drivers should override this method to perform required cleanup + """ self.drivers should override this method to perform required cleanup if any is necessary, such as deleting the test database. The default drops the tables that may be created. - ''' + """ con = self._connect() try: cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): + for ddl in (self.xddl1, self.xddl2): try: cur.execute(ddl) con.commit() @@ -135,9 +136,7 @@ class DatabaseAPI20Test(unittest.TestCase): def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + return self.driver.connect(*self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") @@ -150,7 +149,7 @@ class DatabaseAPI20Test(unittest.TestCase): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, "2.0") except AttributeError: self.fail("Driver doesn't define apilevel") @@ -159,7 +158,7 @@ class DatabaseAPI20Test(unittest.TestCase): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -168,38 +167,24 @@ class DatabaseAPI20Test(unittest.TestCase): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + self.assertTrue( + paramstyle in ("qmark", "numeric", "named", "format", "pyformat") + ) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION @@ -220,7 +205,6 @@ class DatabaseAPI20Test(unittest.TestCase): self.assertTrue(con.ProgrammingError is drv.ProgrammingError) self.assertTrue(con.NotSupportedError is drv.NotSupportedError) - def test_commit(self): con = self._connect() try: @@ -233,7 +217,7 @@ class DatabaseAPI20Test(unittest.TestCase): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): + if hasattr(con, "rollback"): try: con.rollback() except self.driver.NotSupportedError: @@ -242,7 +226,7 @@ class DatabaseAPI20Test(unittest.TestCase): def test_cursor(self): con = self._connect() try: - cur = con.cursor() + _ = con.cursor() finally: con.close() @@ -254,14 +238,14 @@ class DatabaseAPI20Test(unittest.TestCase): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) + cur1.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], "Victoria Bitter") finally: con.close() @@ -270,31 +254,41 @@ class DatabaseAPI20Test(unittest.TestCase): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + len(cur.description), 1, "cursor.description describes too many columns" + ) + self.assertEqual( + len(cur.description[0]), + 7, + "cursor.description[x] tuples must have 7 elements", + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" + % cur.description[0][1], + ) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. DDL)", + ) finally: con.close() @@ -303,47 +297,49 @@ class DatabaseAPI20Test(unittest.TestCase): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount should be -1 after executing no-result " "statements", + ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount not being reset to -1 after executing " + "no-result statements", + ) finally: con.close() - lower_func = 'lower' + lower_func = "lower" + def test_callproc(self): con = self._connect() try: cur = con.cursor() - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) - self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, "callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") finally: con.close() @@ -356,14 +352,14 @@ class DatabaseAPI20Test(unittest.TestCase): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' - self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) + self.assertRaises(self.driver.Error, con.close) def test_execute(self): con = self._connect() @@ -373,105 +369,99 @@ class DatabaseAPI20Test(unittest.TestCase): finally: con.close() - def _paraminsert(self,cur): + def _paraminsert(self, cur): self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue(cur.rowcount in (-1, 1)) - if self.driver.paramstyle == 'qmark': + if self.driver.paramstyle == "qmark": cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "numeric": cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "named": cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, + {"beer": "Cooper's"}, + ) + elif self.driver.paramstyle == "format": cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "pyformat": cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "insert into %sbooze values (%%(beer)s)" % self.table_prefix, + {"beer": "Cooper's"}, + ) else: - self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.fail("Invalid paramstyle") + self.assertTrue(cur.rowcount in (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], + "Cooper's", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) + self.assertEqual( + beers[1], + "Victoria Bitter", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] - if self.driver.paramstyle == 'qmark': + largs = [("Cooper's",), ("Boag's",)] + margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] + if self.driver.paramstyle == "qmark": cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "numeric": cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "named": cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, margs + ) + elif self.driver.paramstyle == "format": cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "pyformat": cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) + "insert into %sbooze values (%%(beer)s)" % (self.table_prefix), + margs, + ) else: - self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount - ) - cur.execute('select name from %sbooze' % self.table_prefix) + self.fail("Unknown paramstyle") + self.assertTrue( + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual( + len(res), 2, "cursor.fetchall retrieved incorrect number of rows" + ) + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") + self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") finally: con.close() @@ -482,59 +472,62 @@ class DatabaseAPI20Test(unittest.TestCase): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if no more rows available", + ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() samples = [ - 'Carlton Cold', - 'Carlton Draft', - 'Mountain Goat', - 'Redback', - 'Victoria Bitter', - 'XXXX' - ] + "Carlton Cold", + "Carlton Draft", + "Mountain Goat", + "Redback", + "Victoria Bitter", + "XXXX", + ] def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch + """ Return a list of sql commands to setup the DB for the fetch tests. - ''' + """ populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + "insert into {}booze values ('{}')".format(self.table_prefix, s) + for s in self.samples + ] return populate def test_fetchmany(self): @@ -543,78 +536,88 @@ class DatabaseAPI20Test(unittest.TestCase): cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", ) - self.assertTrue(cur.rowcount in (-1,6)) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual( + len(r), 3, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " + "results are exhausted", + ) + self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) + cur.arraysize = 4 + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 4, "cursor.arraysize not being honoured by fetchmany" + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + self.assertTrue(cur.rowcount in (-1, 6)) - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + cur.arraysize = 6 + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' + for i in range(0, 6): + self.assertEqual( + rows[i], + self.samples[i], + "incorrect data retrieved by cursor.fetchmany", ) - self.assertTrue(cur.rowcount in (-1,6)) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbarflys" % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence if " + "query retrieved no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() @@ -634,36 +637,41 @@ class DatabaseAPI20Test(unittest.TestCase): # cursor.fetchall should raise an Error if called # after executing a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -676,91 +684,91 @@ class DatabaseAPI20Test(unittest.TestCase): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + cur.execute("select name from %sbooze" % self.table_prefix) + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(rows23), 2, "fetchmany returned incorrect number of rows" + ) + self.assertEqual( + len(rows56), 2, "fetchall returned incorrect number of rows" + ) rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "incorrect data retrieved or inserted" + ) finally: con.close() - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme + def help_nextset_setUp(self, cur): + """ Should create a procedure called deleteme that returns two result sets, first the number of rows in booze then "name from booze" - ''' - raise NotImplementedError('Helper not implemented') - #sql=""" + """ + raise NotImplementedError("Helper not implemented") + # sql=""" # create procedure deleteme as # begin # select count(*) from booze # select name from booze # end - #""" - #cur.execute(sql) + # """ + # cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' - raise NotImplementedError('Helper not implemented') - #cur.execute("drop procedure deleteme") + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" + raise NotImplementedError("Helper not implemented") + # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' + s = cur.nextset() + assert s is None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) finally: con.close() - def test_nextset(self): - raise NotImplementedError('Drivers need to override this test') + def test_nextset(self): # noqa: F811 + raise NotImplementedError("Drivers need to override this test") def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue( + hasattr(cur, "arraysize"), "cursor.arraysize must be defined" + ) finally: con.close() @@ -768,8 +776,8 @@ class DatabaseAPI20Test(unittest.TestCase): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() @@ -779,75 +787,74 @@ class DatabaseAPI20Test(unittest.TestCase): try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): # Real test for setoutputsize is driver dependant - raise NotImplementedError('Driver need to override this test') + raise NotImplementedError("Driver need to override this test") def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("insert into %sbooze values (NULL)" % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, "NULL value not returned as None") finally: con.close() def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + d1 = self.driver.Date(2002, 12, 25) # noqa F841 + d2 = self.driver.DateFromTicks( # noqa F841 + time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + t1 = self.driver.Time(13, 45, 30) # noqa F841 + t2 = self.driver.TimeFromTicks( # noqa F841 + time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) - t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) # noqa F841 + t2 = self.driver.TimestampFromTicks( # noqa F841 + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') + b = self.driver.Binary(b"Something") + b = self.driver.Binary(b"") # noqa F841 def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "BINARY"), "module.BINARY must be defined." + ) def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." + ) def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." + ) def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) - + self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/tests/test_MySQLdb_capabilities.py b/tests/test_MySQLdb_capabilities.py index d5be511..fe9ef03 100644 --- a/tests/test_MySQLdb_capabilities.py +++ b/tests/test_MySQLdb_capabilities.py @@ -1,23 +1,23 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- import capabilities from datetime import timedelta from contextlib import closing import unittest import MySQLdb -from MySQLdb import cursors from configdb import connection_factory import warnings -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") class test_MySQLdb(capabilities.DatabaseTest): db_module = MySQLdb connect_args = () - connect_kwargs = dict(use_unicode=True, sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") + connect_kwargs = dict( + use_unicode=True, sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL" + ) create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" leak_test = False @@ -25,97 +25,113 @@ class test_MySQLdb(capabilities.DatabaseTest): return "`%s`" % ident def test_TIME(self): - def generator(row,col): - return timedelta(0, row*8000) - self.check_data_integrity( - ('col1 TIME',), - generator) + def generator(row, col): + return timedelta(0, row * 8000) + + self.check_data_integrity(("col1 TIME",), generator) def test_TINYINT(self): # Number data def generator(row, col): - v = (row*row) % 256 + v = (row * row) % 256 if v > 127: - v = v-256 + v = v - 256 return v - self.check_data_integrity( - ('col1 TINYINT',), - generator) + + self.check_data_integrity(("col1 TINYINT",), generator) def test_stored_procedures(self): db = self.connection c = self.cursor - self.create_table(('pos INT', 'tree CHAR(20)')) - c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, - list(enumerate('ash birch cedar Lärche pine'.split()))) + self.create_table(("pos INT", "tree CHAR(20)")) + c.executemany( + "INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, + list(enumerate("ash birch cedar Lärche pine".split())), + ) db.commit() - c.execute(""" + c.execute( + """ CREATE PROCEDURE test_sp(IN t VARCHAR(255)) BEGIN SELECT pos FROM %s WHERE tree = t; END - """ % self.table) + """ + % self.table + ) db.commit() - c.callproc('test_sp', ('Lärche',)) + c.callproc("test_sp", ("Lärche",)) rows = c.fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0][0], 3) c.nextset() c.execute("DROP PROCEDURE test_sp") - c.execute('drop table %s' % (self.table)) + c.execute("drop table %s" % (self.table)) def test_small_CHAR(self): # Character data - def generator(row,col): - i = (row*col+62)%256 - if i == 62: return '' - if i == 63: return None + def generator(row, col): + i = (row * col + 62) % 256 + if i == 62: + return "" + if i == 63: + return None return chr(i) - self.check_data_integrity( - ('col1 char(1)','col2 char(1)'), - generator) + + self.check_data_integrity(("col1 char(1)", "col2 char(1)"), generator) def test_BIT(self): c = self.cursor try: - c.execute("""create table test_BIT ( + c.execute( + """create table test_BIT ( b3 BIT(3), b7 BIT(10), - b64 BIT(64))""") + b64 BIT(64))""" + ) - one64 = '1'*64 + one64 = "1" * 64 c.execute( "insert into test_BIT (b3, b7, b64)" - " VALUES (b'011', b'1111111111', b'%s')" - % one64) + " VALUES (b'011', b'1111111111', b'%s')" % one64 + ) c.execute("SELECT b3, b7, b64 FROM test_BIT") row = c.fetchone() - self.assertEqual(row[0], b'\x03') - self.assertEqual(row[1], b'\x03\xff') - self.assertEqual(row[2], b'\xff'*8) + self.assertEqual(row[0], b"\x03") + self.assertEqual(row[1], b"\x03\xff") + self.assertEqual(row[2], b"\xff" * 8) finally: c.execute("drop table if exists test_BIT") def test_MULTIPOLYGON(self): c = self.cursor try: - c.execute("""create table test_MULTIPOLYGON ( + c.execute( + """create table test_MULTIPOLYGON ( id INTEGER PRIMARY KEY, - border MULTIPOLYGON)""") + border MULTIPOLYGON)""" + ) c.execute( - "insert into test_MULTIPOLYGON (id, border)" - " VALUES (1, GeomFromText('MULTIPOLYGON(((1 1, 1 -1, -1 -1, -1 1, 1 1)),((1 1, 3 1, 3 3, 1 3, 1 1)))'))" + """ +INSERT INTO test_MULTIPOLYGON + (id, border) +VALUES (1, + Geomfromtext( +'MULTIPOLYGON(((1 1, 1 -1, -1 -1, -1 1, 1 1)),((1 1, 3 1, 3 3, 1 3, 1 1)))')) +""" ) c.execute("SELECT id, AsText(border) FROM test_MULTIPOLYGON") row = c.fetchone() self.assertEqual(row[0], 1) - self.assertEqual(row[1], 'MULTIPOLYGON(((1 1,1 -1,-1 -1,-1 1,1 1)),((1 1,3 1,3 3,1 3,1 1)))') + self.assertEqual( + row[1], + "MULTIPOLYGON(((1 1,1 -1,-1 -1,-1 1,1 1)),((1 1,3 1,3 3,1 3,1 1)))", + ) c.execute("SELECT id, AsWKB(border) FROM test_MULTIPOLYGON") row = c.fetchone() @@ -131,19 +147,21 @@ class test_MySQLdb(capabilities.DatabaseTest): def test_bug_2671682(self): from MySQLdb.constants import ER + try: - self.cursor.execute("describe some_non_existent_table"); + self.cursor.execute("describe some_non_existent_table") except self.connection.ProgrammingError as msg: self.assertTrue(str(ER.NO_SUCH_TABLE) in str(msg)) def test_bug_3514287(self): c = self.cursor try: - c.execute("""create table bug_3541287 ( + c.execute( + """create table bug_3541287 ( c1 CHAR(10), - t1 TIMESTAMP)""") - c.execute("insert into bug_3541287 (c1,t1) values (%s, NOW())", - ("blah",)) + t1 TIMESTAMP)""" + ) + c.execute("insert into bug_3541287 (c1,t1) values (%s, NOW())", ("blah",)) finally: c.execute("drop table if exists bug_3541287") @@ -164,22 +182,25 @@ class test_MySQLdb(capabilities.DatabaseTest): for binary_prefix in (True, False, None): kwargs = self.connect_kwargs.copy() # needs to be set to can guarantee CHARSET response for normal strings - kwargs['charset'] = 'utf8' - if binary_prefix != None: - kwargs['binary_prefix'] = binary_prefix + kwargs["charset"] = "utf8" + if binary_prefix is not None: + kwargs["binary_prefix"] = binary_prefix with closing(connection_factory(**kwargs)) as conn: with closing(conn.cursor()) as c: - c.execute('SELECT CHARSET(%s)', (MySQLdb.Binary(b'raw bytes'),)) - self.assertEqual(c.fetchall()[0][0], 'binary' if binary_prefix else 'utf8') + c.execute("SELECT CHARSET(%s)", (MySQLdb.Binary(b"raw bytes"),)) + self.assertEqual( + c.fetchall()[0][0], "binary" if binary_prefix else "utf8" + ) # normal strings should not get prefix - c.execute('SELECT CHARSET(%s)', ('str',)) - self.assertEqual(c.fetchall()[0][0], 'utf8') + c.execute("SELECT CHARSET(%s)", ("str",)) + self.assertEqual(c.fetchall()[0][0], "utf8") -if __name__ == '__main__': +if __name__ == "__main__": if test_MySQLdb.leak_test: import gc + gc.enable() gc.set_debug(gc.DEBUG_LEAK) unittest.main() diff --git a/tests/test_MySQLdb_dbapi20.py b/tests/test_MySQLdb_dbapi20.py index 1e808bd..6b3a378 100644 --- a/tests/test_MySQLdb_dbapi20.py +++ b/tests/test_MySQLdb_dbapi20.py @@ -4,17 +4,22 @@ import unittest import MySQLdb from configdb import connection_kwargs import warnings + warnings.simplefilter("ignore") class test_MySQLdb(dbapi20.DatabaseAPI20Test): driver = MySQLdb connect_args = () - connect_kw_args = connection_kwargs(dict(sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + connect_kw_args = connection_kwargs( + dict(sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") + ) - def test_setoutputsize(self): pass - def test_setoutputsize_basic(self): pass - def test_nextset(self): pass + def test_setoutputsize(self): + pass + + def test_setoutputsize_basic(self): + pass """The tests on fetchone and fetchall and rowcount bogusly test for an exception if the statement cannot return a @@ -36,36 +41,41 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): # cursor.fetchall should raise an Error if called # after executing a statement that cannot return rows - #self.assertRaises(self.driver.Error,cur.fetchall) + # self.assertRaises(self.driver.Error,cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -77,39 +87,42 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows self.executeDDL1(cur) -## self.assertRaises(self.driver.Error,cur.fetchone) + # self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + # self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) -## self.assertEqual(cur.fetchone(),None, -## 'cursor.fetchone should return None if no more rows available' -## ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + # self.assertEqual( + # cur.fetchone(), + # None, + # "cursor.fetchone should return None if no more rows available", + # ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() @@ -119,81 +132,93 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): try: cur = con.cursor() self.executeDDL1(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount should be -1 after executing no-result ' -## 'statements' -## ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertTrue(cur.rowcount in (-1,1), -## 'cursor.rowcount should == number or rows inserted, or ' -## 'set to -1 after executing an insert statement' -## ) + # self.assertEqual(cur.rowcount,-1, + # 'cursor.rowcount should be -1 after executing no-result ' + # 'statements' + # ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + # self.assertTrue(cur.rowcount in (-1,1), + # 'cursor.rowcount should == number or rows inserted, or ' + # 'set to -1 after executing an insert statement' + # ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount not being reset to -1 after executing ' -## 'no-result statements' -## ) + # self.assertEqual(cur.rowcount,-1, + # 'cursor.rowcount not being reset to -1 after executing ' + # 'no-result statements' + # ) finally: con.close() def test_callproc(self): - pass # performed in test_MySQL_capabilities + pass # performed in test_MySQL_capabilities - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme + def help_nextset_setUp(self, cur): + """ Should create a procedure called deleteme that returns two result sets, first the number of rows in booze then "name from booze" - ''' - sql=""" + """ + sql = """ create procedure deleteme() begin select count(*) from %(tp)sbooze; select name from %(tp)sbooze; end - """ % dict(tp=self.table_prefix) + """ % dict( + tp=self.table_prefix + ) cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" cur.execute("drop procedure deleteme") def test_nextset(self): - #from warnings import warn + # from warnings import warn + con = self._connect() try: cur = con.cursor() - if not hasattr(cur, 'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() + s = cur.nextset() if s: empty = cur.fetchall() - self.assertEqual(len(empty), 0, - "non-empty result set after other result sets") - #warn("Incompatibility: MySQL returns an empty result set for the CALL itself", - # Warning) - #assert s == None,'No more return sets, should return None' + self.assertEqual( + len(empty), 0, "non-empty result set after other result sets" + ) + # warn( + # ": ".join( + # [ + # "Incompatibility", + # "MySQL returns an empty result set for the CALL itself" + # ] + # ), + # Warning, + # ) + # assert s == None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) @@ -201,5 +226,5 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): con.close() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_MySQLdb_nonstandard.py b/tests/test_MySQLdb_nonstandard.py index c5cacbe..c517dad 100644 --- a/tests/test_MySQLdb_nonstandard.py +++ b/tests/test_MySQLdb_nonstandard.py @@ -5,6 +5,7 @@ import MySQLdb from MySQLdb.constants import FIELD_TYPE from configdb import connection_factory import warnings + warnings.simplefilter("ignore") @@ -36,10 +37,12 @@ class TestCoreModule(unittest.TestCase): self.assertTrue(isinstance(_mysql.get_client_info(), str)) def test_escape_string(self): - self.assertEqual(_mysql.escape_string(b'foo"bar'), - b'foo\\"bar', "escape byte string") - self.assertEqual(_mysql.escape_string(u'foo"bar'), - b'foo\\"bar', "escape unicode string") + self.assertEqual( + _mysql.escape_string(b'foo"bar'), b'foo\\"bar', "escape byte string" + ) + self.assertEqual( + _mysql.escape_string('foo"bar'), b'foo\\"bar', "escape unicode string" + ) class CoreAPI(unittest.TestCase): @@ -53,42 +56,49 @@ class CoreAPI(unittest.TestCase): def test_thread_id(self): tid = self.conn.thread_id() - self.assertTrue(isinstance(tid, int), - "thread_id didn't return an int.") + self.assertTrue(isinstance(tid, int), "thread_id didn't return an int.") - self.assertRaises(TypeError, self.conn.thread_id, ('evil',), - "thread_id shouldn't accept arguments.") + self.assertRaises( + TypeError, + self.conn.thread_id, + ("evil",), + "thread_id shouldn't accept arguments.", + ) def test_affected_rows(self): - self.assertEqual(self.conn.affected_rows(), 0, - "Should return 0 before we do anything.") + self.assertEqual( + self.conn.affected_rows(), 0, "Should return 0 before we do anything." + ) - - #def test_debug(self): - ## FIXME Only actually tests if you lack SUPER - #self.assertRaises(MySQLdb.OperationalError, - #self.conn.dump_debug_info) + # def test_debug(self): + # (FIXME) Only actually tests if you lack SUPER + # self.assertRaises(MySQLdb.OperationalError, + # self.conn.dump_debug_info) def test_charset_name(self): - self.assertTrue(isinstance(self.conn.character_set_name(), str), - "Should return a string.") + self.assertTrue( + isinstance(self.conn.character_set_name(), str), "Should return a string." + ) def test_host_info(self): - self.assertTrue(isinstance(self.conn.get_host_info(), str), - "Should return a string.") + self.assertTrue( + isinstance(self.conn.get_host_info(), str), "Should return a string." + ) def test_proto_info(self): - self.assertTrue(isinstance(self.conn.get_proto_info(), int), - "Should return an int.") + self.assertTrue( + isinstance(self.conn.get_proto_info(), int), "Should return an int." + ) def test_server_info(self): - self.assertTrue(isinstance(self.conn.get_server_info(), str), - "Should return a string.") + self.assertTrue( + isinstance(self.conn.get_server_info(), str), "Should return a string." + ) def test_client_flag(self): conn = connection_factory( - use_unicode=True, - client_flag=MySQLdb.constants.CLIENT.FOUND_ROWS) + use_unicode=True, client_flag=MySQLdb.constants.CLIENT.FOUND_ROWS + ) self.assertIsInstance(conn.client_flag, int) self.assertTrue(conn.client_flag & MySQLdb.constants.CLIENT.FOUND_ROWS) diff --git a/tests/test_MySQLdb_times.py b/tests/test_MySQLdb_times.py index d9d3e02..fdc35ff 100644 --- a/tests/test_MySQLdb_times.py +++ b/tests/test_MySQLdb_times.py @@ -6,104 +6,141 @@ from datetime import time, date, datetime, timedelta from MySQLdb import times import warnings + warnings.simplefilter("ignore") class TestX_or_None(unittest.TestCase): def test_date_or_none(self): - assert times.Date_or_None('1969-01-01') == date(1969, 1, 1) - assert times.Date_or_None('2015-01-01') == date(2015, 1, 1) - assert times.Date_or_None('2015-12-13') == date(2015, 12, 13) + assert times.Date_or_None("1969-01-01") == date(1969, 1, 1) + assert times.Date_or_None("2015-01-01") == date(2015, 1, 1) + assert times.Date_or_None("2015-12-13") == date(2015, 12, 13) - assert times.Date_or_None('') is None - assert times.Date_or_None('fail') is None - assert times.Date_or_None('2015-12') is None - assert times.Date_or_None('2015-12-40') is None - assert times.Date_or_None('0000-00-00') is None + assert times.Date_or_None("") is None + assert times.Date_or_None("fail") is None + assert times.Date_or_None("2015-12") is None + assert times.Date_or_None("2015-12-40") is None + assert times.Date_or_None("0000-00-00") is None def test_time_or_none(self): - assert times.Time_or_None('00:00:00') == time(0, 0) - assert times.Time_or_None('01:02:03') == time(1, 2, 3) - assert times.Time_or_None('01:02:03.123456') == time(1, 2, 3, 123456) + assert times.Time_or_None("00:00:00") == time(0, 0) + assert times.Time_or_None("01:02:03") == time(1, 2, 3) + assert times.Time_or_None("01:02:03.123456") == time(1, 2, 3, 123456) - assert times.Time_or_None('') is None - assert times.Time_or_None('fail') is None - assert times.Time_or_None('24:00:00') is None - assert times.Time_or_None('01:02:03.123456789') is None + assert times.Time_or_None("") is None + assert times.Time_or_None("fail") is None + assert times.Time_or_None("24:00:00") is None + assert times.Time_or_None("01:02:03.123456789") is None def test_datetime_or_none(self): - assert times.DateTime_or_None('1000-01-01') == date(1000, 1, 1) - assert times.DateTime_or_None('2015-12-13') == date(2015, 12, 13) - assert times.DateTime_or_None('2015-12-13 01:02') == datetime(2015, 12, 13, 1, 2) - assert times.DateTime_or_None('2015-12-13T01:02') == datetime(2015, 12, 13, 1, 2) - assert times.DateTime_or_None('2015-12-13 01:02:03') == datetime(2015, 12, 13, 1, 2, 3) - assert times.DateTime_or_None('2015-12-13T01:02:03') == datetime(2015, 12, 13, 1, 2, 3) - assert times.DateTime_or_None('2015-12-13 01:02:03.123') == datetime(2015, 12, 13, 1, 2, 3, 123000) - assert times.DateTime_or_None('2015-12-13 01:02:03.000123') == datetime(2015, 12, 13, 1, 2, 3, 123) - assert times.DateTime_or_None('2015-12-13 01:02:03.123456') == datetime(2015, 12, 13, 1, 2, 3, 123456) - assert times.DateTime_or_None('2015-12-13T01:02:03.123456') == datetime(2015, 12, 13, 1, 2, 3, 123456) + assert times.DateTime_or_None("1000-01-01") == date(1000, 1, 1) + assert times.DateTime_or_None("2015-12-13") == date(2015, 12, 13) + assert times.DateTime_or_None("2015-12-13 01:02") == datetime( + 2015, 12, 13, 1, 2 + ) + assert times.DateTime_or_None("2015-12-13T01:02") == datetime( + 2015, 12, 13, 1, 2 + ) + assert times.DateTime_or_None("2015-12-13 01:02:03") == datetime( + 2015, 12, 13, 1, 2, 3 + ) + assert times.DateTime_or_None("2015-12-13T01:02:03") == datetime( + 2015, 12, 13, 1, 2, 3 + ) + assert times.DateTime_or_None("2015-12-13 01:02:03.123") == datetime( + 2015, 12, 13, 1, 2, 3, 123000 + ) + assert times.DateTime_or_None("2015-12-13 01:02:03.000123") == datetime( + 2015, 12, 13, 1, 2, 3, 123 + ) + assert times.DateTime_or_None("2015-12-13 01:02:03.123456") == datetime( + 2015, 12, 13, 1, 2, 3, 123456 + ) + assert times.DateTime_or_None("2015-12-13T01:02:03.123456") == datetime( + 2015, 12, 13, 1, 2, 3, 123456 + ) - assert times.DateTime_or_None('') is None - assert times.DateTime_or_None('fail') is None - assert times.DateTime_or_None('0000-00-00 00:00:00') is None - assert times.DateTime_or_None('0000-00-00 00:00:00.000000') is None - assert times.DateTime_or_None('2015-12-13T01:02:03.123456789') is None + assert times.DateTime_or_None("") is None + assert times.DateTime_or_None("fail") is None + assert times.DateTime_or_None("0000-00-00 00:00:00") is None + assert times.DateTime_or_None("0000-00-00 00:00:00.000000") is None + assert times.DateTime_or_None("2015-12-13T01:02:03.123456789") is None def test_timedelta_or_none(self): - assert times.TimeDelta_or_None('-1:0:0') == timedelta(0, -3600) - assert times.TimeDelta_or_None('1:0:0') == timedelta(0, 3600) - assert times.TimeDelta_or_None('12:55:30') == timedelta(0, 46530) - assert times.TimeDelta_or_None('12:55:30.123456') == timedelta(0, 46530, 123456) - assert times.TimeDelta_or_None('12:55:30.123456789') == timedelta(0, 46653, 456789) - assert times.TimeDelta_or_None('12:55:30.123456789123456') == timedelta(1429, 37719, 123456) + assert times.TimeDelta_or_None("-1:0:0") == timedelta(0, -3600) + assert times.TimeDelta_or_None("1:0:0") == timedelta(0, 3600) + assert times.TimeDelta_or_None("12:55:30") == timedelta(0, 46530) + assert times.TimeDelta_or_None("12:55:30.123456") == timedelta(0, 46530, 123456) + assert times.TimeDelta_or_None("12:55:30.123456789") == timedelta( + 0, 46653, 456789 + ) + assert times.TimeDelta_or_None("12:55:30.123456789123456") == timedelta( + 1429, 37719, 123456 + ) - assert times.TimeDelta_or_None('') is None - assert times.TimeDelta_or_None('0') is None - assert times.TimeDelta_or_None('fail') is None + assert times.TimeDelta_or_None("") is None + assert times.TimeDelta_or_None("0") is None + assert times.TimeDelta_or_None("fail") is None class TestTicks(unittest.TestCase): - @mock.patch('MySQLdb.times.localtime', side_effect=gmtime) + @mock.patch("MySQLdb.times.localtime", side_effect=gmtime) def test_date_from_ticks(self, mock): assert times.DateFromTicks(0) == date(1970, 1, 1) assert times.DateFromTicks(1430000000) == date(2015, 4, 25) - @mock.patch('MySQLdb.times.localtime', side_effect=gmtime) + @mock.patch("MySQLdb.times.localtime", side_effect=gmtime) def test_time_from_ticks(self, mock): assert times.TimeFromTicks(0) == time(0, 0, 0) assert times.TimeFromTicks(1431100000) == time(15, 46, 40) assert times.TimeFromTicks(1431100000.123) == time(15, 46, 40) - @mock.patch('MySQLdb.times.localtime', side_effect=gmtime) + @mock.patch("MySQLdb.times.localtime", side_effect=gmtime) def test_timestamp_from_ticks(self, mock): assert times.TimestampFromTicks(0) == datetime(1970, 1, 1, 0, 0, 0) assert times.TimestampFromTicks(1430000000) == datetime(2015, 4, 25, 22, 13, 20) - assert times.TimestampFromTicks(1430000000.123) == datetime(2015, 4, 25, 22, 13, 20) + assert times.TimestampFromTicks(1430000000.123) == datetime( + 2015, 4, 25, 22, 13, 20 + ) class TestToLiteral(unittest.TestCase): def test_datetime_to_literal(self): - assert times.DateTime2literal(datetime(2015, 12, 13), '') == b"'2015-12-13 00:00:00'" - assert times.DateTime2literal(datetime(2015, 12, 13, 11, 12, 13), '') == b"'2015-12-13 11:12:13'" - assert times.DateTime2literal(datetime(2015, 12, 13, 11, 12, 13, 123456), '') == b"'2015-12-13 11:12:13.123456'" + self.assertEquals( + times.DateTime2literal(datetime(2015, 12, 13), ""), b"'2015-12-13 00:00:00'" + ) + self.assertEquals( + times.DateTime2literal(datetime(2015, 12, 13, 11, 12, 13), ""), + b"'2015-12-13 11:12:13'", + ) + self.assertEquals( + times.DateTime2literal(datetime(2015, 12, 13, 11, 12, 13, 123456), ""), + b"'2015-12-13 11:12:13.123456'", + ) def test_datetimedelta_to_literal(self): d = datetime(2015, 12, 13, 1, 2, 3) - datetime(2015, 12, 13, 1, 2, 2) - assert times.DateTimeDelta2literal(d, '') == b"'0 0:0:1'" + assert times.DateTimeDelta2literal(d, "") == b"'0 0:0:1'" class TestFormat(unittest.TestCase): def test_format_timedelta(self): d = datetime(2015, 1, 1) - datetime(2015, 1, 1) - assert times.format_TIMEDELTA(d) == '0 0:0:0' + assert times.format_TIMEDELTA(d) == "0 0:0:0" d = datetime(2015, 1, 1, 10, 11, 12) - datetime(2015, 1, 1, 8, 9, 10) - assert times.format_TIMEDELTA(d) == '0 2:2:2' + assert times.format_TIMEDELTA(d) == "0 2:2:2" d = datetime(2015, 1, 1, 10, 11, 12) - datetime(2015, 1, 1, 11, 12, 13) - assert times.format_TIMEDELTA(d) == '-1 22:58:59' + assert times.format_TIMEDELTA(d) == "-1 22:58:59" def test_format_timestamp(self): - assert times.format_TIMESTAMP(datetime(2015, 2, 3)) == '2015-02-03 00:00:00' - assert times.format_TIMESTAMP(datetime(2015, 2, 3, 17, 18, 19)) == '2015-02-03 17:18:19' - assert times.format_TIMESTAMP(datetime(15, 2, 3, 17, 18, 19)) == '0015-02-03 17:18:19' + assert times.format_TIMESTAMP(datetime(2015, 2, 3)) == "2015-02-03 00:00:00" + self.assertEquals( + times.format_TIMESTAMP(datetime(2015, 2, 3, 17, 18, 19)), + "2015-02-03 17:18:19", + ) + self.assertEquals( + times.format_TIMESTAMP(datetime(15, 2, 3, 17, 18, 19)), + "0015-02-03 17:18:19", + ) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index ff96368..479f3e2 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,6 +1,4 @@ -from __future__ import print_function, absolute_import - -import pytest +# import pytest import MySQLdb.cursors from configdb import connection_factory @@ -8,6 +6,7 @@ from configdb import connection_factory _conns = [] _tables = [] + def connect(**kwargs): conn = connection_factory(**kwargs) _conns.append(conn) @@ -19,7 +18,7 @@ def teardown_function(function): c = _conns[0] cur = c.cursor() for t in _tables: - cur.execute("DROP TABLE %s" % (t,)) + cur.execute("DROP TABLE {}".format(t)) cur.close() del _tables[:] @@ -35,47 +34,71 @@ def test_executemany(): cursor.execute("create table test (data varchar(10))") _tables.append("test") - m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)") - assert m is not None, 'error parse %s' - assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + m = MySQLdb.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%s, %s)" + ) + assert m is not None, "error parse %s" + assert m.group(3) == "", "group 3 not blank, bug in RE_INSERT_VALUES?" - m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)") - assert m is not None, 'error parse %(name)s' - assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + m = MySQLdb.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)" + ) + assert m is not None, "error parse %(name)s" + assert m.group(3) == "", "group 3 not blank, bug in RE_INSERT_VALUES?" - m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)") - assert m is not None, 'error parse %(id_name)s' - assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + m = MySQLdb.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)" + ) + assert m is not None, "error parse %(id_name)s" + assert m.group(3) == "", "group 3 not blank, bug in RE_INSERT_VALUES?" - m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update") - assert m is not None, 'error parse %(id_name)s' - assert m.group(3) == ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?' + m = MySQLdb.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update" + ) + assert m is not None, "error parse %(id_name)s" + assert ( + m.group(3) == " ON duplicate update" + ), "group 3 not ON duplicate update, bug in RE_INSERT_VALUES?" # https://github.com/PyMySQL/mysqlclient-python/issues/178 - m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO bloup(foo, bar)VALUES(%s, %s)") + m = MySQLdb.cursors.RE_INSERT_VALUES.match( + "INSERT INTO bloup(foo, bar)VALUES(%s, %s)" + ) assert m is not None - # cursor._executed myst bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" + # cursor._executed myst bee + # """ + # insert into test (data) + # values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9) + # """ # list args data = range(10) cursor.executemany("insert into test (data) values (%s)", data) - assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query' + assert cursor._executed.endswith( + b",(7),(8),(9)" + ), "execute many with %s not in one query" # dict args - data_dict = [{'data': i} for i in range(10)] + data_dict = [{"data": i} for i in range(10)] cursor.executemany("insert into test (data) values (%(data)s)", data_dict) - assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query' + assert cursor._executed.endswith( + b",(7),(8),(9)" + ), "execute many with %(data)s not in one query" # %% in column set - cursor.execute("""\ + cursor.execute( + """\ CREATE TABLE percent_test ( `A%` INTEGER, - `B%` INTEGER)""") + `B%` INTEGER)""" + ) try: q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)" assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None cursor.executemany(q, [(3, 4), (5, 6)]) - assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query" + assert cursor._executed.endswith( + b"(3, 4),(5, 6)" + ), "executemany with %% not in one query" finally: cursor.execute("DROP TABLE IF EXISTS percent_test") @@ -84,7 +107,7 @@ def test_pyparam(): conn = connect() cursor = conn.cursor() - cursor.execute(u"SELECT %(a)s, %(b)s", {u'a': 1, u'b': 2}) + cursor.execute("SELECT %(a)s, %(b)s", {"a": 1, "b": 2}) assert cursor._executed == b"SELECT 1, 2" - cursor.execute(b"SELECT %(a)s, %(b)s", {b'a': 3, b'b': 4}) + cursor.execute(b"SELECT %(a)s, %(b)s", {b"a": 3, b"b": 4}) assert cursor._executed == b"SELECT 3, 4"