From 49e401b3bc2c123b36949baeb409b91dcbdf85a6 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Wed, 11 May 2016 11:25:53 +0900 Subject: [PATCH 1/2] Port executemany() implementation from PyMySQL --- MySQLdb/connections.py | 4 + MySQLdb/cursors.py | 185 ++++++++++++++++++++++++----------------- tests/test_cursor.py | 74 +++++++++++++++++ 3 files changed, 185 insertions(+), 78 deletions(-) create mode 100644 tests/test_cursor.py diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index f51106c..d97e918 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -278,6 +278,9 @@ class Connection(_mysql.connection): return (cursorclass or self.cursorclass)(self) def query(self, query): + # Since _mysql releases GIL while querying, we need immutable buffer. + if isinstance(query, bytearray): + query = bytes(query) if self.waiter is not None: self.send_query(query) self.waiter(self.fileno()) @@ -353,6 +356,7 @@ class Connection(_mysql.connection): self.store_result() self.string_decoder.charset = py_charset self.unicode_literal.charset = py_charset + self.encoding = py_charset def set_sql_mode(self, sql_mode): """Set the connection sql_mode. See MySQL documentation for diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index 661ce35..6260a02 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -2,45 +2,34 @@ This module implements Cursors of various types for MySQLdb. By default, MySQLdb uses the Cursor class. - """ - +from __future__ import print_function, absolute_import +from functools import partial import re import sys -PY2 = sys.version_info[0] == 2 from MySQLdb.compat import unicode +from _mysql_exceptions import ( + Warning, Error, InterfaceError, DataError, + DatabaseError, OperationalError, IntegrityError, InternalError, + NotSupportedError, ProgrammingError) -restr = r""" - \s - values - \s* - ( - \( - [^()']* - (?: - (?: - (?:\( - # ( - editor highlighting helper - .* - \)) - | - ' - [^\\']* - (?:\\.[^\\']*)* - ' - ) - [^()']* - )* - \) - ) -""" -insert_values = re.compile(restr, re.S | re.I | re.X) +PY2 = sys.version_info[0] == 2 +if PY2: + text_type = unicode +else: + text_type = str -from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ - DatabaseError, OperationalError, IntegrityError, InternalError, \ - NotSupportedError, ProgrammingError + +#: Regular expression for :meth:`Cursor.executemany`. +#: executemany only suports simple bulk insert. +#: You can use it to load large dataset. +RE_INSERT_VALUES = re.compile( + r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?)\Z", + re.IGNORECASE | re.DOTALL) class BaseCursor(object): @@ -60,6 +49,12 @@ class BaseCursor(object): default number of rows fetchmany() will fetch """ + #: Max stetement size which :meth:`executemany` generates. + #: + #: Max size of allowed statement is max_allowed_packet - packet_header_size. + #: Default value of max_allowed_packet is 1048576. + max_stmt_length = 64*1024 + from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \ DatabaseError, DataError, OperationalError, IntegrityError, \ InternalError, ProgrammingError, NotSupportedError @@ -102,6 +97,32 @@ class BaseCursor(object): del exc_info self.close() + def _ensure_bytes(self, x, encoding=None): + if isinstance(x, text_type): + x = x.encode(encoding) + elif isinstance(x, (tuple, list)): + x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x) + return x + + def _escape_args(self, args, conn): + ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) + + if isinstance(args, (tuple, list)): + if PY2: + args = tuple(map(ensure_bytes, args)) + return tuple(conn.escape(arg) for arg in args) + elif isinstance(args, dict): + if PY2: + args = dict((ensure_bytes(key), ensure_bytes(val)) for + (key, val) in args.items()) + return dict((key, conn.escape(val)) for (key, val) in args.items()) + else: + # If it's not a dictionary let's try escaping it anyways. + # Worst case it will throw a Value error + if PY2: + args = ensure_bytes(args) + return conn.escape(args) + def _check_executed(self): if not self._executed: self.errorhandler(self, ProgrammingError, "execute() first") @@ -230,62 +251,70 @@ class BaseCursor(object): return res def executemany(self, query, args): + # type: (str, list) -> int """Execute a multi-row query. - query -- string, query to execute on server - - args - - Sequence of sequences or mappings, parameters to use with - query. - - Returns long integer rows affected, if any. + :param query: query to execute on server + :param args: Sequence of sequences or mappings. It is used as parameter. + :return: Number of rows affected, if any. This method improves performance on multiple-row INSERT and REPLACE. Otherwise it is equivalent to looping over args with execute(). """ del self.messages[:] - db = self._get_db() - if not args: return - if PY2 and isinstance(query, unicode): - query = query.encode(db.unicode_literal.charset) - elif not PY2 and isinstance(query, bytes): - query = query.decode(db.unicode_literal.charset) - m = insert_values.search(query) - if not m: - r = 0 - for a in args: - r = r + self.execute(query, a) - return r - p = m.start(1) - e = m.end(1) - qv = m.group(1) - try: - q = [] - for a in args: - if isinstance(a, dict): - q.append(qv % dict((key, db.literal(item)) - for key, item in a.items())) - else: - q.append(qv % tuple([db.literal(item) for item in a])) - except TypeError as msg: - if msg.args[0] in ("not enough arguments for format string", - "not all arguments converted"): - self.errorhandler(self, ProgrammingError, msg.args[0]) + + if not args: + return + + m = RE_INSERT_VALUES.match(query) + if m: + q_prefix = m.group(1) % () + q_values = m.group(2).rstrip() + q_postfix = m.group(3) or '' + assert q_values[0] == '(' and q_values[-1] == ')' + return self._do_execute_many(q_prefix, q_values, q_postfix, args, + self.max_stmt_length, + self._get_db().encoding) + + self.rowcount = sum(self.execute(query, arg) for arg in args) + return self.rowcount + + def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + conn = self._get_db() + escape = self._escape_args + if isinstance(prefix, text_type): + prefix = prefix.encode(encoding) + if PY2 and isinstance(values, text_type): + values = values.encode(encoding) + if isinstance(postfix, text_type): + postfix = postfix.encode(encoding) + sql = bytearray(prefix) + args = iter(args) + v = values % escape(next(args), conn) + if isinstance(v, text_type): + if PY2: + v = v.encode(encoding) else: - self.errorhandler(self, TypeError, msg) - except (SystemExit, KeyboardInterrupt): - raise - except: - exc, value = sys.exc_info()[:2] - self.errorhandler(self, exc, value) - qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]]) - if not PY2: - qs = qs.encode(db.unicode_literal.charset, 'surrogateescape') - r = self._query(qs) - if not self._defer_warnings: self._warning_check() - return r + v = v.encode(encoding, 'surrogateescape') + sql += v + rows = 0 + for arg in args: + v = values % escape(arg, conn) + if isinstance(v, text_type): + if PY2: + v = v.encode(encoding) + else: + v = v.encode(encoding, 'surrogateescape') + if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: + rows += self.execute(sql + postfix) + sql = bytearray(prefix) + else: + sql += b',' + sql += v + rows += self.execute(sql + postfix) + self.rowcount = rows + return rows def callproc(self, procname, args=()): """Execute stored procedure procname with args diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 0000000..04bb1c4 --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,74 @@ +import py.test +import MySQLdb.cursors +from configdb import connection_factory + + +_conns = [] +_tables = [] + +def connect(**kwargs): + conn = connection_factory(**kwargs) + _conns.append(conn) + return conn + + +def teardown_function(function): + if _tables: + c = _conns[0] + cur = c.cursor() + for t in _tables: + cur.execute("DROP TABLE %s" % (t,)) + cur.close() + del _tables[:] + + for c in _conns: + c.close() + del _conns[:] + + +def test_executemany(): + conn = connect() + cursor = conn.cursor() + + cursor.execute("create table test (data varchar(10))") + _tables.append("test") + + m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)") + assert m is not None, 'error parse %s' + assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + + m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)") + assert m is not None, 'error parse %(name)s' + assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + + m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)") + assert m is not None, 'error parse %(id_name)s' + assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?' + + m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update") + assert m is not None, 'error parse %(id_name)s' + assert m.group(3) == ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?' + + # cursor._executed myst bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" + # list args + data = range(10) + cursor.executemany("insert into test (data) values (%s)", data) + assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %s not in one query' + + # dict args + data_dict = [{'data': i} for i in range(10)] + cursor.executemany("insert into test (data) values (%(data)s)", data_dict) + assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %(data)s not in one query' + + # %% in column set + cursor.execute("""\ + CREATE TABLE percent_test ( + `A%` INTEGER, + `B%` INTEGER)""") + try: + q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)" + assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None + cursor.executemany(q, [(3, 4), (5, 6)]) + assert cursor._executed.endswith("(3, 4),(5, 6)"), "executemany with %% not in one query" + finally: + cursor.execute("DROP TABLE IF EXISTS percent_test") From 57dd34dc10e3f8ec7d860dc8bc8e6baccb571b60 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Wed, 11 May 2016 15:24:36 +0900 Subject: [PATCH 2/2] fixup --- MySQLdb/connections.py | 1 - MySQLdb/cursors.py | 6 +++--- tests/test_cursor.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index d97e918..1f69f5c 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -65,7 +65,6 @@ def numeric_part(s): class Connection(_mysql.connection): - """MySQL Database Connection Object""" default_cursor = cursors.Cursor diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index 6260a02..1e0a3f9 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -110,18 +110,18 @@ class BaseCursor(object): if isinstance(args, (tuple, list)): if PY2: args = tuple(map(ensure_bytes, args)) - return tuple(conn.escape(arg) for arg in args) + return tuple(conn.literal(arg) for arg in args) elif isinstance(args, dict): if PY2: args = dict((ensure_bytes(key), ensure_bytes(val)) for (key, val) in args.items()) - return dict((key, conn.escape(val)) for (key, val) in args.items()) + return dict((key, conn.literal(val)) for (key, val) in args.items()) else: # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error if PY2: args = ensure_bytes(args) - return conn.escape(args) + return conn.literal(args) def _check_executed(self): if not self._executed: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 04bb1c4..bfdcb33 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -53,12 +53,12 @@ def test_executemany(): # list args data = range(10) cursor.executemany("insert into test (data) values (%s)", data) - assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %s not in one query' + assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query' # dict args data_dict = [{'data': i} for i in range(10)] cursor.executemany("insert into test (data) values (%(data)s)", data_dict) - assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %(data)s not in one query' + assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query' # %% in column set cursor.execute("""\ @@ -69,6 +69,6 @@ def test_executemany(): q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)" assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None cursor.executemany(q, [(3, 4), (5, 6)]) - assert cursor._executed.endswith("(3, 4),(5, 6)"), "executemany with %% not in one query" + assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query" finally: cursor.execute("DROP TABLE IF EXISTS percent_test")