mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 11:10:58 +08:00
Merge pull request #83 from methane/fix/executemany-double-percent
Port executemany() implementation from PyMySQL
This commit is contained in:
@ -65,7 +65,6 @@ def numeric_part(s):
|
|||||||
|
|
||||||
|
|
||||||
class Connection(_mysql.connection):
|
class Connection(_mysql.connection):
|
||||||
|
|
||||||
"""MySQL Database Connection Object"""
|
"""MySQL Database Connection Object"""
|
||||||
|
|
||||||
default_cursor = cursors.Cursor
|
default_cursor = cursors.Cursor
|
||||||
@ -278,6 +277,9 @@ class Connection(_mysql.connection):
|
|||||||
return (cursorclass or self.cursorclass)(self)
|
return (cursorclass or self.cursorclass)(self)
|
||||||
|
|
||||||
def query(self, query):
|
def query(self, query):
|
||||||
|
# Since _mysql releases GIL while querying, we need immutable buffer.
|
||||||
|
if isinstance(query, bytearray):
|
||||||
|
query = bytes(query)
|
||||||
if self.waiter is not None:
|
if self.waiter is not None:
|
||||||
self.send_query(query)
|
self.send_query(query)
|
||||||
self.waiter(self.fileno())
|
self.waiter(self.fileno())
|
||||||
@ -353,6 +355,7 @@ class Connection(_mysql.connection):
|
|||||||
self.store_result()
|
self.store_result()
|
||||||
self.string_decoder.charset = py_charset
|
self.string_decoder.charset = py_charset
|
||||||
self.unicode_literal.charset = py_charset
|
self.unicode_literal.charset = py_charset
|
||||||
|
self.encoding = py_charset
|
||||||
|
|
||||||
def set_sql_mode(self, sql_mode):
|
def set_sql_mode(self, sql_mode):
|
||||||
"""Set the connection sql_mode. See MySQL documentation for
|
"""Set the connection sql_mode. See MySQL documentation for
|
||||||
|
@ -2,45 +2,34 @@
|
|||||||
|
|
||||||
This module implements Cursors of various types for MySQLdb. By
|
This module implements Cursors of various types for MySQLdb. By
|
||||||
default, MySQLdb uses the Cursor class.
|
default, MySQLdb uses the Cursor class.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from __future__ import print_function, absolute_import
|
||||||
|
from functools import partial
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
PY2 = sys.version_info[0] == 2
|
|
||||||
|
|
||||||
from MySQLdb.compat import unicode
|
from MySQLdb.compat import unicode
|
||||||
|
from _mysql_exceptions import (
|
||||||
|
Warning, Error, InterfaceError, DataError,
|
||||||
|
DatabaseError, OperationalError, IntegrityError, InternalError,
|
||||||
|
NotSupportedError, ProgrammingError)
|
||||||
|
|
||||||
restr = r"""
|
|
||||||
\s
|
|
||||||
values
|
|
||||||
\s*
|
|
||||||
(
|
|
||||||
\(
|
|
||||||
[^()']*
|
|
||||||
(?:
|
|
||||||
(?:
|
|
||||||
(?:\(
|
|
||||||
# ( - editor highlighting helper
|
|
||||||
.*
|
|
||||||
\))
|
|
||||||
|
|
|
||||||
'
|
|
||||||
[^\\']*
|
|
||||||
(?:\\.[^\\']*)*
|
|
||||||
'
|
|
||||||
)
|
|
||||||
[^()']*
|
|
||||||
)*
|
|
||||||
\)
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
insert_values = re.compile(restr, re.S | re.I | re.X)
|
PY2 = sys.version_info[0] == 2
|
||||||
|
if PY2:
|
||||||
|
text_type = unicode
|
||||||
|
else:
|
||||||
|
text_type = str
|
||||||
|
|
||||||
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
|
|
||||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
#: Regular expression for :meth:`Cursor.executemany`.
|
||||||
NotSupportedError, ProgrammingError
|
#: executemany only supports simple bulk insert.
|
||||||
|
#: You can use it to load large dataset.
|
||||||
|
RE_INSERT_VALUES = re.compile(
|
||||||
|
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
|
||||||
|
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
|
||||||
|
r"(\s*(?:ON DUPLICATE.*)?)\Z",
|
||||||
|
re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
class BaseCursor(object):
|
class BaseCursor(object):
|
||||||
@ -60,6 +49,12 @@ class BaseCursor(object):
|
|||||||
default number of rows fetchmany() will fetch
|
default number of rows fetchmany() will fetch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
#: Max statement size which :meth:`executemany` generates.
|
||||||
|
#:
|
||||||
|
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
|
||||||
|
#: Default value of max_allowed_packet is 1048576.
|
||||||
|
max_stmt_length = 64*1024
|
||||||
|
|
||||||
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
|
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
|
||||||
DatabaseError, DataError, OperationalError, IntegrityError, \
|
DatabaseError, DataError, OperationalError, IntegrityError, \
|
||||||
InternalError, ProgrammingError, NotSupportedError
|
InternalError, ProgrammingError, NotSupportedError
|
||||||
@ -102,6 +97,32 @@ class BaseCursor(object):
|
|||||||
del exc_info
|
del exc_info
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
def _ensure_bytes(self, x, encoding=None):
|
||||||
|
if isinstance(x, text_type):
|
||||||
|
x = x.encode(encoding)
|
||||||
|
elif isinstance(x, (tuple, list)):
|
||||||
|
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _escape_args(self, args, conn):
|
||||||
|
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
|
||||||
|
|
||||||
|
if isinstance(args, (tuple, list)):
|
||||||
|
if PY2:
|
||||||
|
args = tuple(map(ensure_bytes, args))
|
||||||
|
return tuple(conn.literal(arg) for arg in args)
|
||||||
|
elif isinstance(args, dict):
|
||||||
|
if PY2:
|
||||||
|
args = dict((ensure_bytes(key), ensure_bytes(val)) for
|
||||||
|
(key, val) in args.items())
|
||||||
|
return dict((key, conn.literal(val)) for (key, val) in args.items())
|
||||||
|
else:
|
||||||
|
# If it's not a dictionary let's try escaping it anyways.
|
||||||
|
# Worst case it will throw a Value error
|
||||||
|
if PY2:
|
||||||
|
args = ensure_bytes(args)
|
||||||
|
return conn.literal(args)
|
||||||
|
|
||||||
def _check_executed(self):
|
def _check_executed(self):
|
||||||
if not self._executed:
|
if not self._executed:
|
||||||
self.errorhandler(self, ProgrammingError, "execute() first")
|
self.errorhandler(self, ProgrammingError, "execute() first")
|
||||||
@ -230,62 +251,70 @@ class BaseCursor(object):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def executemany(self, query, args):
|
def executemany(self, query, args):
|
||||||
|
# type: (str, list) -> int
|
||||||
"""Execute a multi-row query.
|
"""Execute a multi-row query.
|
||||||
|
|
||||||
query -- string, query to execute on server
|
:param query: query to execute on server
|
||||||
|
:param args: Sequence of sequences or mappings. It is used as parameter.
|
||||||
args
|
:return: Number of rows affected, if any.
|
||||||
|
|
||||||
Sequence of sequences or mappings, parameters to use with
|
|
||||||
query.
|
|
||||||
|
|
||||||
Returns long integer rows affected, if any.
|
|
||||||
|
|
||||||
This method improves performance on multiple-row INSERT and
|
This method improves performance on multiple-row INSERT and
|
||||||
REPLACE. Otherwise it is equivalent to looping over args with
|
REPLACE. Otherwise it is equivalent to looping over args with
|
||||||
execute().
|
execute().
|
||||||
"""
|
"""
|
||||||
del self.messages[:]
|
del self.messages[:]
|
||||||
db = self._get_db()
|
|
||||||
if not args: return
|
if not args:
|
||||||
if PY2 and isinstance(query, unicode):
|
return
|
||||||
query = query.encode(db.unicode_literal.charset)
|
|
||||||
elif not PY2 and isinstance(query, bytes):
|
m = RE_INSERT_VALUES.match(query)
|
||||||
query = query.decode(db.unicode_literal.charset)
|
if m:
|
||||||
m = insert_values.search(query)
|
q_prefix = m.group(1) % ()
|
||||||
if not m:
|
q_values = m.group(2).rstrip()
|
||||||
r = 0
|
q_postfix = m.group(3) or ''
|
||||||
for a in args:
|
assert q_values[0] == '(' and q_values[-1] == ')'
|
||||||
r = r + self.execute(query, a)
|
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
|
||||||
return r
|
self.max_stmt_length,
|
||||||
p = m.start(1)
|
self._get_db().encoding)
|
||||||
e = m.end(1)
|
|
||||||
qv = m.group(1)
|
self.rowcount = sum(self.execute(query, arg) for arg in args)
|
||||||
try:
|
return self.rowcount
|
||||||
q = []
|
|
||||||
for a in args:
|
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
|
||||||
if isinstance(a, dict):
|
conn = self._get_db()
|
||||||
q.append(qv % dict((key, db.literal(item))
|
escape = self._escape_args
|
||||||
for key, item in a.items()))
|
if isinstance(prefix, text_type):
|
||||||
|
prefix = prefix.encode(encoding)
|
||||||
|
if PY2 and isinstance(values, text_type):
|
||||||
|
values = values.encode(encoding)
|
||||||
|
if isinstance(postfix, text_type):
|
||||||
|
postfix = postfix.encode(encoding)
|
||||||
|
sql = bytearray(prefix)
|
||||||
|
args = iter(args)
|
||||||
|
v = values % escape(next(args), conn)
|
||||||
|
if isinstance(v, text_type):
|
||||||
|
if PY2:
|
||||||
|
v = v.encode(encoding)
|
||||||
else:
|
else:
|
||||||
q.append(qv % tuple([db.literal(item) for item in a]))
|
v = v.encode(encoding, 'surrogateescape')
|
||||||
except TypeError as msg:
|
sql += v
|
||||||
if msg.args[0] in ("not enough arguments for format string",
|
rows = 0
|
||||||
"not all arguments converted"):
|
for arg in args:
|
||||||
self.errorhandler(self, ProgrammingError, msg.args[0])
|
v = values % escape(arg, conn)
|
||||||
|
if isinstance(v, text_type):
|
||||||
|
if PY2:
|
||||||
|
v = v.encode(encoding)
|
||||||
else:
|
else:
|
||||||
self.errorhandler(self, TypeError, msg)
|
v = v.encode(encoding, 'surrogateescape')
|
||||||
except (SystemExit, KeyboardInterrupt):
|
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
|
||||||
raise
|
rows += self.execute(sql + postfix)
|
||||||
except:
|
sql = bytearray(prefix)
|
||||||
exc, value = sys.exc_info()[:2]
|
else:
|
||||||
self.errorhandler(self, exc, value)
|
sql += b','
|
||||||
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
|
sql += v
|
||||||
if not PY2:
|
rows += self.execute(sql + postfix)
|
||||||
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
|
self.rowcount = rows
|
||||||
r = self._query(qs)
|
return rows
|
||||||
if not self._defer_warnings: self._warning_check()
|
|
||||||
return r
|
|
||||||
|
|
||||||
def callproc(self, procname, args=()):
|
def callproc(self, procname, args=()):
|
||||||
"""Execute stored procedure procname with args
|
"""Execute stored procedure procname with args
|
||||||
|
74
tests/test_cursor.py
Normal file
74
tests/test_cursor.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import py.test
|
||||||
|
import MySQLdb.cursors
|
||||||
|
from configdb import connection_factory
|
||||||
|
|
||||||
|
|
||||||
|
_conns = []
|
||||||
|
_tables = []
|
||||||
|
|
||||||
|
def connect(**kwargs):
|
||||||
|
conn = connection_factory(**kwargs)
|
||||||
|
_conns.append(conn)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def teardown_function(function):
|
||||||
|
if _tables:
|
||||||
|
c = _conns[0]
|
||||||
|
cur = c.cursor()
|
||||||
|
for t in _tables:
|
||||||
|
cur.execute("DROP TABLE %s" % (t,))
|
||||||
|
cur.close()
|
||||||
|
del _tables[:]
|
||||||
|
|
||||||
|
for c in _conns:
|
||||||
|
c.close()
|
||||||
|
del _conns[:]
|
||||||
|
|
||||||
|
|
||||||
|
def test_executemany():
|
||||||
|
conn = connect()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("create table test (data varchar(10))")
|
||||||
|
_tables.append("test")
|
||||||
|
|
||||||
|
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)")
|
||||||
|
assert m is not None, 'error parse %s'
|
||||||
|
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
|
||||||
|
|
||||||
|
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)")
|
||||||
|
assert m is not None, 'error parse %(name)s'
|
||||||
|
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
|
||||||
|
|
||||||
|
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)")
|
||||||
|
assert m is not None, 'error parse %(id_name)s'
|
||||||
|
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
|
||||||
|
|
||||||
|
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update")
|
||||||
|
assert m is not None, 'error parse %(id_name)s'
|
||||||
|
assert m.group(3) == ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?'
|
||||||
|
|
||||||
|
# cursor._executed must be "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
|
||||||
|
# list args
|
||||||
|
data = range(10)
|
||||||
|
cursor.executemany("insert into test (data) values (%s)", data)
|
||||||
|
assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query'
|
||||||
|
|
||||||
|
# dict args
|
||||||
|
data_dict = [{'data': i} for i in range(10)]
|
||||||
|
cursor.executemany("insert into test (data) values (%(data)s)", data_dict)
|
||||||
|
assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query'
|
||||||
|
|
||||||
|
# %% in column set
|
||||||
|
cursor.execute("""\
|
||||||
|
CREATE TABLE percent_test (
|
||||||
|
`A%` INTEGER,
|
||||||
|
`B%` INTEGER)""")
|
||||||
|
try:
|
||||||
|
q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
|
||||||
|
assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None
|
||||||
|
cursor.executemany(q, [(3, 4), (5, 6)])
|
||||||
|
assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query"
|
||||||
|
finally:
|
||||||
|
cursor.execute("DROP TABLE IF EXISTS percent_test")
|
Reference in New Issue
Block a user