Merge pull request #83 from methane/fix/executemany-double-percent

Port executemany() implementation from PyMySQL
This commit is contained in:
INADA Naoki
2016-05-11 18:41:53 +09:00
3 changed files with 185 additions and 79 deletions

View File

@ -65,7 +65,6 @@ def numeric_part(s):
class Connection(_mysql.connection):
"""MySQL Database Connection Object"""
default_cursor = cursors.Cursor
@ -278,6 +277,9 @@ class Connection(_mysql.connection):
return (cursorclass or self.cursorclass)(self)
def query(self, query):
# Since _mysql releases GIL while querying, we need immutable buffer.
if isinstance(query, bytearray):
query = bytes(query)
if self.waiter is not None:
self.send_query(query)
self.waiter(self.fileno())
@ -353,6 +355,7 @@ class Connection(_mysql.connection):
self.store_result()
self.string_decoder.charset = py_charset
self.unicode_literal.charset = py_charset
self.encoding = py_charset
def set_sql_mode(self, sql_mode):
"""Set the connection sql_mode. See MySQL documentation for

View File

@ -2,45 +2,34 @@
This module implements Cursors of various types for MySQLdb. By
default, MySQLdb uses the Cursor class.
"""
from __future__ import print_function, absolute_import
from functools import partial
import re
import sys
PY2 = sys.version_info[0] == 2
from MySQLdb.compat import unicode
from _mysql_exceptions import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
restr = r"""
\s
values
\s*
(
\(
[^()']*
(?:
(?:
(?:\(
# ( - editor highlighting helper
.*
\))
|
'
[^\\']*
(?:\\.[^\\']*)*
'
)
[^()']*
)*
\)
)
"""
insert_values = re.compile(restr, re.S | re.I | re.X)
PY2 = sys.version_info[0] == 2
if PY2:
text_type = unicode
else:
text_type = str
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError
#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
#:
#: Group 1 captures the ``INSERT``/``REPLACE ... VALUES`` prefix, group 2 the
#: parenthesised tuple of ``%s`` / ``%(name)s`` placeholders, and group 3 an
#: optional trailing ``ON DUPLICATE ...`` clause (empty string when absent).
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?)\Z",
    flags=re.I | re.S,
)
class BaseCursor(object):
@ -60,6 +49,12 @@ class BaseCursor(object):
default number of rows fetchmany() will fetch
"""
#: Max statement size which :meth:`executemany` generates.
#:
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
#: Default value of max_allowed_packet is 1048576.
max_stmt_length = 64*1024
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
DatabaseError, DataError, OperationalError, IntegrityError, \
InternalError, ProgrammingError, NotSupportedError
@ -102,6 +97,32 @@ class BaseCursor(object):
del exc_info
self.close()
def _ensure_bytes(self, x, encoding=None):
    """Return *x* with text strings encoded to bytes using *encoding*.

    A text string is encoded directly.  A tuple or list is rebuilt as the
    same container type with each element converted recursively.  Any
    other value is handed back untouched.
    """
    if isinstance(x, text_type):
        return x.encode(encoding)
    if isinstance(x, (tuple, list)):
        container = type(x)
        return container(
            self._ensure_bytes(item, encoding=encoding) for item in x)
    return x
def _escape_args(self, args, conn):
    """Escape query parameters with ``conn.literal``.

    :param args: a tuple/list, a dict, or a single value.
    :param conn: connection providing ``encoding`` and ``literal()``.
    :return: a tuple of literals for a tuple/list, a dict with the same
        keys and literal values for a dict, otherwise a single literal.

    On Python 2 text values (and dict keys) are first encoded to bytes
    with the connection encoding.
    """
    to_bytes = partial(self._ensure_bytes, encoding=conn.encoding)

    if isinstance(args, (tuple, list)):
        if PY2:
            args = tuple(map(to_bytes, args))
        return tuple(conn.literal(item) for item in args)

    if isinstance(args, dict):
        if PY2:
            args = dict((to_bytes(name), to_bytes(value))
                        for (name, value) in args.items())
        return dict((name, conn.literal(value))
                    for (name, value) in args.items())

    # Neither a sequence nor a mapping: try escaping it anyway.
    # Worst case conn.literal() raises for an unsupported value.
    if PY2:
        args = to_bytes(args)
    return conn.literal(args)
def _check_executed(self):
if not self._executed:
self.errorhandler(self, ProgrammingError, "execute() first")
@ -230,62 +251,70 @@ class BaseCursor(object):
return res
def executemany(self, query, args):
# type: (str, list) -> int
"""Execute a multi-row query.

:param query: query to execute on server
:param args: Sequence of sequences or mappings. It is used as parameter.
:return: Number of rows affected, if any.

This method improves performance on multiple-row INSERT and
REPLACE. Otherwise it is equivalent to looping over args with
execute().
"""
# NOTE(review): the statements below appear to interleave the previous and
# the new implementation of this method (diff view with +/- markers and
# indentation stripped) -- confirm against the real file before relying on
# statement order.
# Empty the messages list in place.
del self.messages[:]
db = self._get_db()
if not args: return
# Convert the query to the native string type of the running Python,
# using the connection's unicode_literal charset.
if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset)
elif not PY2 and isinstance(query, bytes):
query = query.decode(db.unicode_literal.charset)
# insert_values is searched anywhere in the query; its group 1 is the
# parenthesised tuple that follows VALUES.
m = insert_values.search(query)
if not m:
# No VALUES tuple found: call execute() once per parameter set and add up
# the return values.
r = 0
for a in args:
r = r + self.execute(query, a)
return r
# Span and text of the VALUES tuple, used to splice the expanded rows back
# into the query.
p = m.start(1)
e = m.end(1)
qv = m.group(1)
try:
q = []
for a in args:
if isinstance(a, dict):
q.append(qv % dict((key, db.literal(item))
for key, item in a.items()))
# Nothing to do for an empty parameter sequence; returns None.
if not args:
return
# RE_INSERT_VALUES must match the query from its start through its end.
m = RE_INSERT_VALUES.match(query)
if m:
# Formatting with an empty tuple turns every '%%' in the prefix into a
# literal '%'.
q_prefix = m.group(1) % ()
q_values = m.group(2).rstrip()
# group(3) is replaced by '' when it is falsy.
q_postfix = m.group(3) or ''
assert q_values[0] == '(' and q_values[-1] == ')'
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
self.max_stmt_length,
self._get_db().encoding)
# Fallback for queries the regex does not match: one execute() per
# parameter set; rowcount is the sum of the execute() return values.
self.rowcount = sum(self.execute(query, arg) for arg in args)
return self.rowcount
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
# Build statements of the form prefix + row[,row...] + postfix from the
# parameter sets in args and run each one through self.execute().
# Returns the sum of the execute() results and also stores it in
# self.rowcount.
#
# NOTE(review): this span appears to interleave lines of the removed
# executemany() body (the q.append / except / qs / _query lines) with the
# new method (diff markers and indentation stripped) -- confirm against the
# real file before relying on statement order.
conn = self._get_db()
escape = self._escape_args
# prefix and postfix are encoded to bytes whenever they are text; the
# values template is only encoded on Python 2.
if isinstance(prefix, text_type):
prefix = prefix.encode(encoding)
if PY2 and isinstance(values, text_type):
values = values.encode(encoding)
if isinstance(postfix, text_type):
postfix = postfix.encode(encoding)
# Mutable buffer holding the statement being assembled.
sql = bytearray(prefix)
args = iter(args)
# The first parameter set is consumed up front and appended without a
# leading comma.
# NOTE(review): next() raises StopIteration on an empty iterable; the
# visible caller returns early when args is falsy -- verify other callers.
v = values % escape(next(args), conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
q.append(qv % tuple([db.literal(item) for item in a]))
except TypeError as msg:
if msg.args[0] in ("not enough arguments for format string",
"not all arguments converted"):
self.errorhandler(self, ProgrammingError, msg.args[0])
v = v.encode(encoding, 'surrogateescape')
sql += v
# Running total of the execute() return values.
rows = 0
for arg in args:
v = values % escape(arg, conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
self.errorhandler(self, TypeError, msg)
except (SystemExit, KeyboardInterrupt):
raise
except:
exc, value = sys.exc_info()[:2]
self.errorhandler(self, exc, value)
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
if not PY2:
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
r = self._query(qs)
if not self._defer_warnings: self._warning_check()
return r
v = v.encode(encoding, 'surrogateescape')
# If adding this row would push the statement past max_stmt_length, send
# what has been collected and start a new statement from the prefix;
# the "+ 1" presumably accounts for the b',' separator.
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql + postfix)
sql = bytearray(prefix)
else:
sql += b','
sql += v
# Send whatever is left in the buffer.
rows += self.execute(sql + postfix)
self.rowcount = rows
return rows
def callproc(self, procname, args=()):
"""Execute stored procedure procname with args

74
tests/test_cursor.py Normal file
View File

@ -0,0 +1,74 @@
import py.test
import MySQLdb.cursors
from configdb import connection_factory
_conns = []
_tables = []
def connect(**kwargs):
    """Open a connection through ``connection_factory`` and remember it.

    Keyword arguments are forwarded unchanged.  The new connection is
    appended to the module-level ``_conns`` list before being returned.
    """
    connection = connection_factory(**kwargs)
    _conns.append(connection)
    return connection
def teardown_function(function):
    """Drop tables listed in ``_tables`` and close connections in ``_conns``.

    Named and shaped like pytest's per-function teardown hook; *function*
    is accepted for that signature and is not used here.

    Tables are dropped through the first registered connection.  The
    registries are emptied and the connections closed even when a
    ``DROP TABLE`` fails, so one failing teardown cannot leave stale
    tables or open connections registered for the following test.  The
    original exception, if any, still propagates.
    """
    try:
        if _tables:
            cur = _conns[0].cursor()
            for table in _tables:
                cur.execute("DROP TABLE %s" % (table,))
            cur.close()
    finally:
        # Always reset module state, whether or not the drops succeeded.
        del _tables[:]
        for conn in _conns:
            conn.close()
        del _conns[:]
def test_executemany():
    """Check ``RE_INSERT_VALUES`` parsing and bulk inserts via ``executemany``.

    First verifies which queries the regular expression accepts and what
    its third group holds, then confirms through ``cursor._executed`` that
    list, dict and ``%%``-containing inserts are sent as a single statement.
    """
    conn = connect()
    cursor = conn.cursor()

    cursor.execute("create table test (data varchar(10))")
    _tables.append("test")

    blank_msg = 'group 3 not blank, bug in RE_INSERT_VALUES?'
    # (query, message when parsing fails, expected group 3, message on mismatch)
    regex_cases = (
        ("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)",
         'error parse %s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)",
         'error parse %(name)s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)",
         'error parse %(id_name)s', '', blank_msg),
        ("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update",
         'error parse %(id_name)s', ' ON duplicate update',
         'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?'),
    )
    for sql, parse_msg, expected_tail, tail_msg in regex_cases:
        matched = MySQLdb.cursors.RE_INSERT_VALUES.match(sql)
        assert matched is not None, parse_msg
        assert matched.group(3) == expected_tail, tail_msg

    # cursor._executed must be "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
    single_stmt_tail = b",(7),(8),(9)"

    # list args
    rows = range(10)
    cursor.executemany("insert into test (data) values (%s)", rows)
    assert cursor._executed.endswith(single_stmt_tail), 'execute many with %s not in one query'

    # dict args
    row_dicts = [{'data': i} for i in range(10)]
    cursor.executemany("insert into test (data) values (%(data)s)", row_dicts)
    assert cursor._executed.endswith(single_stmt_tail), 'execute many with %(data)s not in one query'

    # %% in column set
    cursor.execute("""\
CREATE TABLE percent_test (
`A%` INTEGER,
`B%` INTEGER)""")
    try:
        percent_sql = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
        assert MySQLdb.cursors.RE_INSERT_VALUES.match(percent_sql) is not None
        cursor.executemany(percent_sql, [(3, 4), (5, 6)])
        assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query"
    finally:
        cursor.execute("DROP TABLE IF EXISTS percent_test")