From c66b43cc2200ece4a32485420faf3f4114f0a51c Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 17 Apr 2014 22:36:01 +0900 Subject: [PATCH] Connection.literal() always returns `str` instance. --- MySQLdb/connections.py | 8 +++++++- MySQLdb/cursors.py | 26 ++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index a7adc99..b42817f 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -13,6 +13,9 @@ from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ NotSupportedError, ProgrammingError import _mysql import re +import sys + +PY2 = sys.version_info[0] == 2 def defaulterrorhandler(connection, cursor, errorclass, errorvalue): @@ -280,7 +283,10 @@ class Connection(_mysql.connection): applications. """ - return self.escape(o, self.encoders) + s = self.escape(o, self.encoders) + if not PY2 and isinstance(s, bytes): + return s.decode('ascii', 'surrogateescape') + return s def begin(self): """Explicitly begin a connection. Non-standard. diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index f63cbee..f179306 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -11,7 +11,7 @@ PY2 = sys.version_info[0] == 2 from MySQLdb.compat import unicode -restr = br""" +restr = r""" \s values \s* @@ -179,23 +179,12 @@ class BaseCursor(object): db = self._get_db() if PY2 and isinstance(query, unicode): query = query.encode(db.unicode_literal.charset) - else: - def decode(x): - if isinstance(x, bytes): - x = x.decode('ascii', 'surrogateescape') - return x if args is not None: if isinstance(args, dict): - if PY2: - args = dict((key, db.literal(item)) for key, item in args.iteritems()) - else: - args = dict((key, decode(db.literal(item))) for key, item in args.items()) + args = dict((key, db.literal(item)) for key, item in args.iteritems()) else: - if PY2: - args = tuple(map(db.literal, args)) - else: - args = tuple([decode(db.literal(x)) for x in args]) + args = tuple(map(db.literal, args)) if not PY2 and isinstance(query, bytes): query = query.decode(db.unicode_literal.charset) query = query % args @@ -246,8 +235,10 @@ class BaseCursor(object): del self.messages[:] db = self._get_db() if not args: return - if isinstance(query, unicode): + if PY2 and isinstance(query, unicode): query = query.encode(db.unicode_literal.charset) + elif not PY2 and isinstance(query, bytes): + query = query.decode(db.unicode_literal.charset) m = insert_values.search(query) if not m: r = 0 @@ -277,7 +268,10 @@ class BaseCursor(object): exc, value, tb = sys.exc_info() del tb self.errorhandler(self, exc, value) - r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]])) + qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]]) + if not PY2: + qs = qs.encode(db.unicode_literal.charset, 'surrogateescape') + r = self._query(qs) if not self._defer_warnings: self._warning_check() return r