Random fixes.

This commit is contained in:
INADA Naoki
2014-04-17 22:24:46 +09:00
parent 9cdef719a6
commit 382fb9f9b3
3 changed files with 48 additions and 49 deletions

View File

@ -73,7 +73,7 @@ def test_DBAPISet_set_inequality_membership():
assert FIELD_TYPE.DATE != STRING assert FIELD_TYPE.DATE != STRING
def Binary(x): def Binary(x):
return str(x) return bytes(x)
def Connect(*args, **kwargs): def Connect(*args, **kwargs):
"""Factory function for connections.Connection.""" """Factory function for connections.Connection."""

View File

@ -38,13 +38,15 @@ from MySQLdb.times import *
try: try:
from types import IntType, LongType, FloatType, NoneType, TupleType, ListType, DictType, InstanceType, \ from types import IntType, LongType, FloatType, NoneType, TupleType, ListType, DictType, InstanceType, \
StringType, UnicodeType, ObjectType, BooleanType, ClassType, TypeType ObjectType, BooleanType
PY2 = True
except ImportError: except ImportError:
# Python 3 # Python 3
long = int long = int
IntType, LongType, FloatType, NoneType = int, long, float, type(None) IntType, LongType, FloatType, NoneType = int, long, float, type(None)
TupleType, ListType, DictType, InstanceType = tuple, list, dict, None TupleType, ListType, DictType, InstanceType = tuple, list, dict, None
StringType, UnicodeType, ObjectType, BooleanType = bytes, str, object, bool ObjectType, BooleanType = object, bool
PY2 = False
import array import array
@ -95,34 +97,6 @@ def Thing2Literal(o, d):
return string_literal(o, d) return string_literal(o, d)
def Instance2Str(o, d):
"""
Convert an Instance to a string representation. If the __str__()
method produces acceptable output, then you don't need to add the
class to conversions; it will be handled by the default
converter. If the exact class is not found in d, it will use the
first class it can find for which o is an instance.
"""
if o.__class__ in d:
return d[o.__class__](o, d)
cl = filter(lambda x,o=o:
type(x) is ClassType
and isinstance(o, x), d.keys())
if not cl:
cl = filter(lambda x,o=o:
type(x) is TypeType
and isinstance(o, x)
and d[x] is not Instance2Str,
d.keys())
if not cl:
return d[StringType](o,d)
d[o.__class__] = d[cl[0]]
return d[cl[0]](o, d)
def char_array(s): def char_array(s):
return array.array('c', s) return array.array('c', s)
@ -140,14 +114,12 @@ conversions = {
TupleType: quote_tuple, TupleType: quote_tuple,
ListType: quote_tuple, ListType: quote_tuple,
DictType: escape_dict, DictType: escape_dict,
InstanceType: Instance2Str,
ArrayType: array2Str, ArrayType: array2Str,
StringType: Thing2Literal, # default
UnicodeType: Unicode2Str,
ObjectType: Instance2Str,
BooleanType: Bool2Str, BooleanType: Bool2Str,
Date: Thing2Literal,
DateTimeType: DateTime2literal, DateTimeType: DateTime2literal,
DateTimeDeltaType: DateTimeDelta2literal, DateTimeDeltaType: DateTimeDelta2literal,
str: str, # default
set: Set2Str, set: Set2Str,
FIELD_TYPE.TINY: int, FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int, FIELD_TYPE.SHORT: int,
@ -165,18 +137,22 @@ conversions = {
FIELD_TYPE.TIME: TimeDelta_or_None, FIELD_TYPE.TIME: TimeDelta_or_None,
FIELD_TYPE.DATE: Date_or_None, FIELD_TYPE.DATE: Date_or_None,
FIELD_TYPE.BLOB: [ FIELD_TYPE.BLOB: [
(FLAG.BINARY, str), (FLAG.BINARY, bytes),
], ],
FIELD_TYPE.STRING: [ FIELD_TYPE.STRING: [
(FLAG.BINARY, str), (FLAG.BINARY, bytes),
], ],
FIELD_TYPE.VAR_STRING: [ FIELD_TYPE.VAR_STRING: [
(FLAG.BINARY, str), (FLAG.BINARY, bytes),
], ],
FIELD_TYPE.VARCHAR: [ FIELD_TYPE.VARCHAR: [
(FLAG.BINARY, str), (FLAG.BINARY, bytes),
], ],
} }
if PY2:
conversions[unicode] = Unicode2Str
else:
conversions[bytes] = bytes
try: try:
from decimal import Decimal from decimal import Decimal

View File

@ -7,10 +7,11 @@ default, MySQLdb uses the Cursor class.
import re import re
import sys import sys
PY2 = sys.version_info[0] == 2
from MySQLdb.compat import unicode from MySQLdb.compat import unicode
restr = r""" restr = br"""
\s \s
values values
\s* \s*
@ -68,9 +69,9 @@ class BaseCursor(object):
_defer_warnings = False _defer_warnings = False
def __init__(self, connection): def __init__(self, connection):
from weakref import proxy from weakref import ref
self.connection = proxy(connection) self.connection = ref(connection)
self.description = None self.description = None
self.description_flags = None self.description_flags = None
self.rowcount = -1 self.rowcount = -1
@ -91,7 +92,8 @@ class BaseCursor(object):
def close(self): def close(self):
"""Close the cursor. No further queries will be possible.""" """Close the cursor. No further queries will be possible."""
if not self.connection: return if self.connection is None or self.connection() is None:
return
while self.nextset(): pass while self.nextset(): pass
self.connection = None self.connection = None
@ -152,9 +154,12 @@ class BaseCursor(object):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""
def _get_db(self): def _get_db(self):
if not self.connection: con = self.connection
if con is not None:
con = con()
if con is None:
self.errorhandler(self, ProgrammingError, "cursor closed") self.errorhandler(self, ProgrammingError, "cursor closed")
return self.connection return con
def execute(self, query, args=None): def execute(self, query, args=None):
@ -172,14 +177,32 @@ class BaseCursor(object):
""" """
del self.messages[:] del self.messages[:]
db = self._get_db() db = self._get_db()
if isinstance(query, unicode): if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset) query = query.encode(db.unicode_literal.charset)
else:
def decode(x):
if isinstance(x, bytes):
x = x.decode('ascii', 'surrogateescape')
return x
if args is not None: if args is not None:
if isinstance(args, dict): if isinstance(args, dict):
query = query % dict((key, db.literal(item)) if PY2:
for key, item in args.iteritems()) args = dict((key, db.literal(item)) for key, item in args.iteritems())
else: else:
query = query % tuple([db.literal(item) for item in args]) args = dict((key, decode(db.literal(item))) for key, item in args.items())
else:
if PY2:
args = tuple(map(db.literal, args))
else:
args = tuple([decode(db.literal(x)) for x in args])
if not PY2 and isinstance(query, bytes):
query = query.decode(db.unicode_literal.charset)
query = query % args
if isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset, 'surrogateescape')
try: try:
r = None r = None
r = self._query(query) r = self._query(query)