Connection.literal() always returns str instance.

This commit is contained in:
INADA Naoki
2014-04-17 22:36:01 +09:00
parent 382fb9f9b3
commit c66b43cc22
2 changed files with 17 additions and 17 deletions

View File

@ -13,6 +13,9 @@ from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
NotSupportedError, ProgrammingError
import _mysql
import re
import sys
PY2 = sys.version_info[0] == 2
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
@ -280,7 +283,10 @@ class Connection(_mysql.connection):
applications.
"""
return self.escape(o, self.encoders)
s = self.escape(o, self.encoders)
if not PY2 and isinstance(s, bytes):
return s.decode('ascii', 'surrogateescape')
return s
def begin(self):
"""Explicitly begin a connection. Non-standard.

View File

@ -11,7 +11,7 @@ PY2 = sys.version_info[0] == 2
from MySQLdb.compat import unicode
restr = br"""
restr = r"""
\s
values
\s*
@ -179,23 +179,12 @@ class BaseCursor(object):
db = self._get_db()
if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset)
else:
def decode(x):
if isinstance(x, bytes):
x = x.decode('ascii', 'surrogateescape')
return x
if args is not None:
if isinstance(args, dict):
if PY2:
args = dict((key, db.literal(item)) for key, item in args.iteritems())
else:
args = dict((key, decode(db.literal(item))) for key, item in args.items())
else:
if PY2:
args = tuple(map(db.literal, args))
else:
args = tuple([decode(db.literal(x)) for x in args])
if not PY2 and isinstance(query, bytes):
query = query.decode(db.unicode_literal.charset)
query = query % args
@ -246,8 +235,10 @@ class BaseCursor(object):
del self.messages[:]
db = self._get_db()
if not args: return
if isinstance(query, unicode):
if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset)
elif not PY2 and isinstance(query, bytes):
query = query.decode(db.unicode_literal.charset)
m = insert_values.search(query)
if not m:
r = 0
@ -277,7 +268,10 @@ class BaseCursor(object):
exc, value, tb = sys.exc_info()
del tb
self.errorhandler(self, exc, value)
r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]]))
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
if not PY2:
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
r = self._query(qs)
if not self._defer_warnings: self._warning_check()
return r