Merge pull request #83 from methane/fix/executemany-double-percent

Port executemany() implementation from PyMySQL
This commit is contained in:
INADA Naoki
2016-05-11 18:41:53 +09:00
3 changed files with 185 additions and 79 deletions

View File

@ -65,7 +65,6 @@ def numeric_part(s):
class Connection(_mysql.connection): class Connection(_mysql.connection):
"""MySQL Database Connection Object""" """MySQL Database Connection Object"""
default_cursor = cursors.Cursor default_cursor = cursors.Cursor
@ -278,6 +277,9 @@ class Connection(_mysql.connection):
return (cursorclass or self.cursorclass)(self) return (cursorclass or self.cursorclass)(self)
def query(self, query): def query(self, query):
# Since _mysql releases GIL while querying, we need immutable buffer.
if isinstance(query, bytearray):
query = bytes(query)
if self.waiter is not None: if self.waiter is not None:
self.send_query(query) self.send_query(query)
self.waiter(self.fileno()) self.waiter(self.fileno())
@ -353,6 +355,7 @@ class Connection(_mysql.connection):
self.store_result() self.store_result()
self.string_decoder.charset = py_charset self.string_decoder.charset = py_charset
self.unicode_literal.charset = py_charset self.unicode_literal.charset = py_charset
self.encoding = py_charset
def set_sql_mode(self, sql_mode): def set_sql_mode(self, sql_mode):
"""Set the connection sql_mode. See MySQL documentation for """Set the connection sql_mode. See MySQL documentation for

View File

@ -2,45 +2,34 @@
This module implements Cursors of various types for MySQLdb. By This module implements Cursors of various types for MySQLdb. By
default, MySQLdb uses the Cursor class. default, MySQLdb uses the Cursor class.
""" """
from __future__ import print_function, absolute_import
from functools import partial
import re import re
import sys import sys
PY2 = sys.version_info[0] == 2
from MySQLdb.compat import unicode from MySQLdb.compat import unicode
from _mysql_exceptions import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
restr = r"""
\s
values
\s*
(
\(
[^()']*
(?:
(?:
(?:\(
# ( - editor highlighting helper
.*
\))
|
'
[^\\']*
(?:\\.[^\\']*)*
'
)
[^()']*
)*
\)
)
"""
insert_values = re.compile(restr, re.S | re.I | re.X) PY2 = sys.version_info[0] == 2
# Native text type for the running interpreter: ``unicode`` when PY2 is
# true, ``str`` otherwise.
text_type = unicode if PY2 else str
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \ #: Regular expression for :meth:`Cursor.executemany`.
NotSupportedError, ProgrammingError #: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
# Matches a bulk-insert style statement and splits it into three groups.
# Matching is case-insensitive and ``.`` also matches newlines.
RE_INSERT_VALUES = re.compile(
    # group 1: optional leading whitespace is skipped, then everything from
    # INSERT/REPLACE up to and including VALUE(S) and the whitespace after it
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
    # group 2: one parenthesised list of %s / %(name)s placeholders
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    # group 3: optional trailing "ON DUPLICATE ..." clause (may be empty)
    r"(\s*(?:ON DUPLICATE.*)?)\Z",
    re.I | re.S,
)
class BaseCursor(object): class BaseCursor(object):
@ -60,6 +49,12 @@ class BaseCursor(object):
default number of rows fetchmany() will fetch default number of rows fetchmany() will fetch
""" """
#: Max statement size which :meth:`executemany` generates.
#:
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
#: Default value of max_allowed_packet is 1048576.
max_stmt_length = 64*1024
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \ from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
DatabaseError, DataError, OperationalError, IntegrityError, \ DatabaseError, DataError, OperationalError, IntegrityError, \
InternalError, ProgrammingError, NotSupportedError InternalError, ProgrammingError, NotSupportedError
@ -102,6 +97,32 @@ class BaseCursor(object):
del exc_info del exc_info
self.close() self.close()
def _ensure_bytes(self, x, encoding=None):
    """Encode text to bytes, recursing into tuples and lists.

    :param x: value to convert. Text (``text_type``) is encoded; a tuple
        or list is rebuilt as the same type with each element converted;
        anything else is returned unchanged.
    :param encoding: passed straight to ``.encode()``.
    :return: the converted value.
    """
    # NOTE(review): the default ``encoding=None`` is handed to .encode()
    # as-is; the caller in this file always supplies an encoding.
    if isinstance(x, text_type):
        return x.encode(encoding)
    if isinstance(x, (tuple, list)):
        container = type(x)
        return container(
            self._ensure_bytes(item, encoding=encoding) for item in x)
    return x
def _escape_args(self, args, conn):
    """Run query parameters through ``conn.literal``.

    On Python 2 (``PY2``) the parameters are first converted with
    ``_ensure_bytes`` using ``conn.encoding``.

    :param args: tuple/list, dict, or a single value.
    :param conn: connection providing ``encoding`` and ``literal``.
    :return: a tuple for tuple/list input, a dict with the same keys
        (byte-encoded on Python 2) for dict input, otherwise the result
        of ``conn.literal(args)``.
    """
    to_bytes = partial(self._ensure_bytes, encoding=conn.encoding)

    if isinstance(args, (tuple, list)):
        if PY2:
            args = tuple(to_bytes(item) for item in args)
        return tuple(conn.literal(item) for item in args)

    if isinstance(args, dict):
        if PY2:
            args = dict(
                (to_bytes(name), to_bytes(value))
                for (name, value) in args.items())
        return dict(
            (name, conn.literal(value)) for (name, value) in args.items())

    # Neither a sequence nor a mapping: escape it as a single value and
    # let conn.literal raise if it cannot handle the type.
    if PY2:
        args = to_bytes(args)
    return conn.literal(args)
def _check_executed(self): def _check_executed(self):
if not self._executed: if not self._executed:
self.errorhandler(self, ProgrammingError, "execute() first") self.errorhandler(self, ProgrammingError, "execute() first")
@ -230,62 +251,70 @@ class BaseCursor(object):
return res return res
def executemany(self, query, args): def executemany(self, query, args):
# type: (str, list) -> int
"""Execute a multi-row query. """Execute a multi-row query.
query -- string, query to execute on server :param query: query to execute on server
:param args: Sequence of sequences or mappings. It is used as parameter.
args :return: Number of rows affected, if any.
Sequence of sequences or mappings, parameters to use with
query.
Returns long integer rows affected, if any.
This method improves performance on multiple-row INSERT and This method improves performance on multiple-row INSERT and
REPLACE. Otherwise it is equivalent to looping over args with REPLACE. Otherwise it is equivalent to looping over args with
execute(). execute().
""" """
del self.messages[:] del self.messages[:]
db = self._get_db()
if not args: return if not args:
if PY2 and isinstance(query, unicode): return
query = query.encode(db.unicode_literal.charset)
elif not PY2 and isinstance(query, bytes): m = RE_INSERT_VALUES.match(query)
query = query.decode(db.unicode_literal.charset) if m:
m = insert_values.search(query) q_prefix = m.group(1) % ()
if not m: q_values = m.group(2).rstrip()
r = 0 q_postfix = m.group(3) or ''
for a in args: assert q_values[0] == '(' and q_values[-1] == ')'
r = r + self.execute(query, a) return self._do_execute_many(q_prefix, q_values, q_postfix, args,
return r self.max_stmt_length,
p = m.start(1) self._get_db().encoding)
e = m.end(1)
qv = m.group(1) self.rowcount = sum(self.execute(query, arg) for arg in args)
try: return self.rowcount
q = []
for a in args: def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
if isinstance(a, dict): conn = self._get_db()
q.append(qv % dict((key, db.literal(item)) escape = self._escape_args
for key, item in a.items())) if isinstance(prefix, text_type):
prefix = prefix.encode(encoding)
if PY2 and isinstance(values, text_type):
values = values.encode(encoding)
if isinstance(postfix, text_type):
postfix = postfix.encode(encoding)
sql = bytearray(prefix)
args = iter(args)
v = values % escape(next(args), conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else: else:
q.append(qv % tuple([db.literal(item) for item in a])) v = v.encode(encoding, 'surrogateescape')
except TypeError as msg: sql += v
if msg.args[0] in ("not enough arguments for format string", rows = 0
"not all arguments converted"): for arg in args:
self.errorhandler(self, ProgrammingError, msg.args[0]) v = values % escape(arg, conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else: else:
self.errorhandler(self, TypeError, msg) v = v.encode(encoding, 'surrogateescape')
except (SystemExit, KeyboardInterrupt): if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
raise rows += self.execute(sql + postfix)
except: sql = bytearray(prefix)
exc, value = sys.exc_info()[:2] else:
self.errorhandler(self, exc, value) sql += b','
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]]) sql += v
if not PY2: rows += self.execute(sql + postfix)
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape') self.rowcount = rows
r = self._query(qs) return rows
if not self._defer_warnings: self._warning_check()
return r
def callproc(self, procname, args=()): def callproc(self, procname, args=()):
"""Execute stored procedure procname with args """Execute stored procedure procname with args

74
tests/test_cursor.py Normal file
View File

@ -0,0 +1,74 @@
import py.test
import MySQLdb.cursors
from configdb import connection_factory
_conns = []
_tables = []
def connect(**kwargs):
    """Open a connection with ``connection_factory`` and record it.

    All keyword arguments are forwarded to ``connection_factory``. The
    new connection is appended to ``_conns`` and then returned.
    """
    opened = connection_factory(**kwargs)
    _conns.append(opened)
    return opened
def teardown_function(function):
    """Drop the tables in ``_tables`` and close the connections in ``_conns``.

    :param function: the test function being torn down (unused).

    Cleanup is wrapped in ``try``/``finally`` so that a failing
    ``DROP TABLE`` still closes the cursor and every tracked connection,
    and still empties both registries. Otherwise stale entries would
    carry over into the next test's teardown. Any exception raised
    while dropping tables still propagates to the caller.
    """
    try:
        if _tables:
            # Tables are dropped through the first tracked connection.
            cur = _conns[0].cursor()
            try:
                for t in _tables:
                    cur.execute("DROP TABLE %s" % (t,))
            finally:
                cur.close()
    finally:
        del _tables[:]
        try:
            for c in _conns:
                c.close()
        finally:
            del _conns[:]
def test_executemany():
    """Check RE_INSERT_VALUES parsing and that executemany() batches rows.

    The first part matches sample statements against
    ``MySQLdb.cursors.RE_INSERT_VALUES`` and checks group 3. The second
    part runs ``executemany()`` with list args, dict args and a statement
    whose column names contain ``%%``, then checks the tail of
    ``cursor._executed`` in each case.
    """
    conn = connect()
    cursor = conn.cursor()

    cursor.execute("create table test (data varchar(10))")
    _tables.append("test")

    # (statement, parse-failure message, expected group 3, group-3 message)
    blank_msg = 'group 3 not blank, bug in RE_INSERT_VALUES?'
    regex_cases = [
        ("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)",
         'error parse %s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)",
         'error parse %(name)s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)",
         'error parse %(id_name)s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update",
         'error parse %(id_name)s', ' ON duplicate update',
         'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?'),
    ]
    for statement, parse_msg, expected_tail, tail_msg in regex_cases:
        m = MySQLdb.cursors.RE_INSERT_VALUES.match(statement)
        assert m is not None, parse_msg
        assert m.group(3) == expected_tail, tail_msg

    # cursor._executed must be
    # "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
    batched_tail = b",(7),(8),(9)"

    # list args
    data = range(10)
    cursor.executemany("insert into test (data) values (%s)", data)
    assert cursor._executed.endswith(batched_tail), 'execute many with %s not in one query'

    # dict args
    rows = [{'data': i} for i in range(10)]
    cursor.executemany("insert into test (data) values (%(data)s)", rows)
    assert cursor._executed.endswith(batched_tail), 'execute many with %(data)s not in one query'

    # %% in column set
    cursor.execute("""\
CREATE TABLE percent_test (
`A%` INTEGER,
`B%` INTEGER)""")
    try:
        insert_sql = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
        assert MySQLdb.cursors.RE_INSERT_VALUES.match(insert_sql) is not None
        cursor.executemany(insert_sql, [(3, 4), (5, 6)])
        assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query"
    finally:
        # Always remove the scratch table, even when an assertion fails.
        cursor.execute("DROP TABLE IF EXISTS percent_test")