Fix some inconsistent spacing.

Try to fix some memory leaks. I think cursors don't leak any more
but I've had no luck with connections. If you close your connections
you should be fine, even if you don't close your cursors.
This commit is contained in:
adustman
2006-03-28 05:03:35 +00:00
parent e5d609b344
commit 426d27d4ae
6 changed files with 200 additions and 163 deletions

View File

@ -30,6 +30,8 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
cursor.messages.append(error) cursor.messages.append(error)
else: else:
connection.messages.append(error) connection.messages.append(error)
del cursor
del connection
raise errorclass, errorvalue raise errorclass, errorvalue
@ -125,33 +127,30 @@ class Connection(_mysql.connection):
""" """
from constants import CLIENT, FIELD_TYPE from constants import CLIENT, FIELD_TYPE
from converters import conversions from converters import conversions
from weakref import proxy, WeakValueDictionary
import types import types
kwargs2 = kwargs.copy() kwargs2 = kwargs.copy()
if kwargs.has_key('conv'):
kwargs2['conv'] = conv = kwargs['conv'].copy()
else:
kwargs2['conv'] = conv = conversions.copy()
if kwargs.has_key('cursorclass'):
self.cursorclass = kwargs['cursorclass']
del kwargs2['cursorclass']
else:
self.cursorclass = self.default_cursor
charset = kwargs.get('charset', '') if kwargs.has_key('conv'):
if kwargs.has_key('charset'): conv = kwargs['conv']
del kwargs2['charset'] else:
conv = conversions
kwargs2['conv'] = dict([ (k, v) for k, v in conv.items()
if type(k) is int ])
self.cursorclass = kwargs2.pop('cursorclass', self.default_cursor)
charset = kwargs2.pop('charset', '')
if charset: if charset:
use_unicode = True use_unicode = True
else: else:
use_unicode = False use_unicode = False
use_unicode = kwargs.get('use_unicode', use_unicode)
if kwargs.has_key('use_unicode'):
del kwargs2['use_unicode']
sql_mode = kwargs.get('sql_mode', '') use_unicode = kwargs2.pop('use_unicode', use_unicode)
if kwargs.has_key('sql_mode'): sql_mode = kwargs2.pop('sql_mode', '')
del kwargs2['sql_mode']
client_flag = kwargs.get('client_flag', 0) client_flag = kwargs.get('client_flag', 0)
client_version = tuple([ int(n) for n in _mysql.get_client_info().split('.')[:2] ]) client_version = tuple([ int(n) for n in _mysql.get_client_info().split('.')[:2] ])
@ -164,38 +163,48 @@ class Connection(_mysql.connection):
super(Connection, self).__init__(*args, **kwargs2) super(Connection, self).__init__(*args, **kwargs2)
self.encoders = dict([ (k, v) for k, v in conv.items()
if type(k) is not int ])
self._server_version = tuple([ int(n) for n in self.get_server_info().split('.')[:2] ]) self._server_version = tuple([ int(n) for n in self.get_server_info().split('.')[:2] ])
self.charset = self.character_set_name() db = proxy(self)
if charset: def _get_string_literal():
def string_literal(obj, dummy=None):
return db.string_literal(obj)
return string_literal
def _get_unicode_literal():
def unicode_literal(u, dummy=None):
return db.literal(u.encode(unicode_literal.charset))
return unicode_literal
def _get_string_decoder():
def string_decoder(s):
return s.decode(string_decoder.charset)
return string_decoder
string_literal = _get_string_literal()
self.unicode_literal = unicode_literal = _get_unicode_literal()
self.string_decoder = string_decoder = _get_string_decoder()
if not charset:
charset = self.character_set_name()
self.set_character_set(charset) self.set_character_set(charset)
self.charset = charset
if sql_mode: if sql_mode:
self.set_sql_mode(sql_mode) self.set_sql_mode(sql_mode)
if use_unicode: if use_unicode:
def u(s): self.converter[FIELD_TYPE.STRING].insert(-1, (None, string_decoder))
# can't refer to self.character_set_name() self.converter[FIELD_TYPE.VAR_STRING].insert(-1, (None, string_decoder))
# because this results in reference cycles self.converter[FIELD_TYPE.BLOB].insert(-1, (None, string_decoder))
# and memory leaks
return s.decode(charset)
conv[FIELD_TYPE.STRING].insert(-1, (None, u))
conv[FIELD_TYPE.VAR_STRING].insert(-1, (None, u))
conv[FIELD_TYPE.BLOB].insert(-1, (None, u))
def string_literal(obj, dummy=None): self.encoders[types.StringType] = string_literal
return self.string_literal(obj) self.encoders[types.UnicodeType] = unicode_literal
def unicode_literal(u, dummy=None):
return self.literal(u.encode(self.charset))
self.converter[types.StringType] = string_literal
self.converter[types.UnicodeType] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional: if self._transactional:
# PEP-249 requires autocommit to be initially off # PEP-249 requires autocommit to be initially off
self.autocommit(0) self.autocommit(False)
self.messages = [] self.messages = []
def cursor(self, cursorclass=None): def cursor(self, cursorclass=None):
@ -220,7 +229,7 @@ class Connection(_mysql.connection):
applications. applications.
""" """
return self.escape(o, self.converter) return self.escape(o, self.encoders)
def begin(self): def begin(self):
"""Explicitly begin a connection. Non-standard. """Explicitly begin a connection. Non-standard.
@ -243,21 +252,22 @@ class Connection(_mysql.connection):
else: else:
return 0 return 0
if not hasattr(_mysql.connection, 'set_character_set'):
def set_character_set(self, charset): def set_character_set(self, charset):
"""Set the connection character set. This version """Set the connection character set to charset."""
uses the SET NAMES <charset> SQL statement. try:
super(Connection, self).set_character_set(charset)
You probably shouldn't try to change character sets except AttributeError:
after opening the connection."""
if self._server_version < (4, 1): if self._server_version < (4, 1):
raise UnsupportedError, "server is too old to set charset" raise UnsupportedError, "server is too old to set charset"
if self.charset == charset: return if self.character_set_name() != charset:
self.query('SET NAMES %s' % charset) self.query('SET NAMES %s' % charset)
self.store_result() self.store_result()
self.string_decoder.charset = charset
self.unicode_literal.charset = charset
def set_sql_mode(self, sql_mode): def set_sql_mode(self, sql_mode):
"""Set the connection sql_mode. See MySQL documentation for
legal values."""
if self._server_version < (4, 1): if self._server_version < (4, 1):
raise UnsupportedError, "server is too old to set sql_mode" raise UnsupportedError, "server is too old to set sql_mode"
self.query("SET SESSION sql_mode='%s'" % sql_mode) self.query("SET SESSION sql_mode='%s'" % sql_mode)

View File

@ -36,7 +36,9 @@ class BaseCursor(object):
InternalError, ProgrammingError, NotSupportedError InternalError, ProgrammingError, NotSupportedError
def __init__(self, connection): def __init__(self, connection):
self.connection = connection from weakref import proxy
self.connection = proxy(connection)
self.description = None self.description = None
self.description_flags = None self.description_flags = None
self.rowcount = -1 self.rowcount = -1
@ -68,7 +70,7 @@ class BaseCursor(object):
def _warning_check(self): def _warning_check(self):
from warnings import warn from warnings import warn
if self._warnings: if self._warnings:
warnings = self.connection.show_warnings() warnings = self._get_db().show_warnings()
if warnings: if warnings:
# This is done in two loops in case # This is done in two loops in case
# Warnings are set to raise exceptions. # Warnings are set to raise exceptions.
@ -101,7 +103,7 @@ class BaseCursor(object):
def _post_get_result(self): pass def _post_get_result(self): pass
def _do_get_result(self): def _do_get_result(self):
db = self.connection db = self._get_db()
self._result = self._get_result() self._result = self._get_result()
self.rowcount = db.affected_rows() self.rowcount = db.affected_rows()
self.rownumber = 0 self.rownumber = 0
@ -139,9 +141,11 @@ class BaseCursor(object):
from types import ListType, TupleType from types import ListType, TupleType
from sys import exc_info from sys import exc_info
del self.messages[:] del self.messages[:]
query = query.encode(self.connection.charset) db = self._get_db()
charset = db.character_set_name()
query = query.encode(charset)
if args is not None: if args is not None:
query = query % self.connection.literal(args) query = query % db.literal(args)
try: try:
r = self._query(query) r = self._query(query)
except TypeError, m: except TypeError, m:
@ -180,6 +184,7 @@ class BaseCursor(object):
""" """
del self.messages[:] del self.messages[:]
db = self._get_db()
if not args: return if not args: return
m = insert_values.search(query) m = insert_values.search(query)
if not m: if not m:
@ -188,9 +193,10 @@ class BaseCursor(object):
r = r + self.execute(query, a) r = r + self.execute(query, a)
return r return r
p = m.start(1) p = m.start(1)
query = query.encode(self.connection.charset) charset = db.character_set_name()
query = query.encode(charset)
qv = query[p:] qv = query[p:]
qargs = self.connection.literal(args) qargs = db.literal(args)
try: try:
q = [ query % qargs[0] ] q = [ query % qargs[0] ]
q.extend([ qv % a for a in qargs[1:] ]) q.extend([ qv % a for a in qargs[1:] ])
@ -243,11 +249,12 @@ class BaseCursor(object):
from types import UnicodeType from types import UnicodeType
db = self._get_db() db = self._get_db()
charset = db.character_set_name()
for index, arg in enumerate(args): for index, arg in enumerate(args):
q = "SET @_%s_%d=%s" % (procname, index, q = "SET @_%s_%d=%s" % (procname, index,
db.literal(arg)) db.literal(arg))
if type(q) is UnicodeType: if type(q) is UnicodeType:
q = q.encode(db.charset) q = q.encode(charset)
self._query(q) self._query(q)
self.nextset() self.nextset()
@ -255,7 +262,7 @@ class BaseCursor(object):
','.join(['@_%s_%d' % (procname, i) ','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))])) for i in range(len(args))]))
if type(q) is UnicodeType: if type(q) is UnicodeType:
q = q.encode(db.charset) q = q.encode(charset)
self._query(q) self._query(q)
self._warning_check() self._warning_check()
return args return args

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
version: 1.2.1c7 version: 1.2.1c8
version_info: (1,2,1,'gamma',7) version_info: (1,2,1,'gamma',8)
description: Python interface to MySQL description: Python interface to MySQL
long_description: long_description:
========================= =========================

View File

@ -13,6 +13,7 @@ class test_MySQLdb(test_capabilities.DatabaseTest):
connect_kwargs = dict(db='test', read_default_file='~/.my.cnf', connect_kwargs = dict(db='test', read_default_file='~/.my.cnf',
charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
leak_test = True
def quote_identifier(self, ident): def quote_identifier(self, ident):
return "`%s`" % ident return "`%s`" % ident
@ -76,5 +77,9 @@ class test_MySQLdb(test_capabilities.DatabaseTest):
if __name__ == '__main__': if __name__ == '__main__':
if test_MySQLdb.leak_test:
import gc
gc.enable()
gc.set_debug(gc.DEBUG_LEAK)
unittest.main() unittest.main()
print '''"Huh-huh, he said 'unit'." -- Butthead''' print '''"Huh-huh, he said 'unit'." -- Butthead'''

View File

@ -190,8 +190,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
empty = cur.fetchall() empty = cur.fetchall()
self.assertEquals(len(empty), 0, self.assertEquals(len(empty), 0,
"non-empty result set after other result sets") "non-empty result set after other result sets")
warn("Incompatibility: MySQL returns an empty result set for the CALL itself", #warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
Warning) # Warning)
#assert s == None,'No more return sets, should return None' #assert s == None,'No more return sets, should return None'
finally: finally:
self.help_nextset_tearDown(cur) self.help_nextset_tearDown(cur)

View File

@ -20,6 +20,7 @@ class DatabaseTest(unittest.TestCase):
debug = False debug = False
def setUp(self): def setUp(self):
import gc
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
self.connection = db self.connection = db
self.cursor = db.cursor() self.cursor = db.cursor()
@ -27,6 +28,20 @@ class DatabaseTest(unittest.TestCase):
self.BLOBUText = u''.join([unichr(i) for i in range(16384)]) self.BLOBUText = u''.join([unichr(i) for i in range(16384)])
self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16)) self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
leak_test = True
if leak_test:
def tearDown(self):
import gc
del self.cursor
orphans = gc.collect()
self.failIf(orphans, "%d orphaned objects found after deleting cursor" % orphans)
del self.connection
orphans = gc.collect()
self.failIf(orphans, "%d orphaned objects found after deleting connection" % orphans)
def table_exists(self, name): def table_exists(self, name):
try: try:
self.cursor.execute('select * from %s where 1=0' % name) self.cursor.execute('select * from %s where 1=0' % name)