Fix some inconsistent spacing.

Try to fix some memory leaks. I think cursors don't leak any more
but I've had no luck with connections. If you close your connections
you should be fine, even if you don't close your cursors.
This commit is contained in:
adustman
2006-03-28 05:03:35 +00:00
parent e5d609b344
commit 426d27d4ae
6 changed files with 200 additions and 163 deletions

View File

@ -30,6 +30,8 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
cursor.messages.append(error)
else:
connection.messages.append(error)
del cursor
del connection
raise errorclass, errorvalue
@ -125,33 +127,30 @@ class Connection(_mysql.connection):
"""
from constants import CLIENT, FIELD_TYPE
from converters import conversions
from weakref import proxy, WeakValueDictionary
import types
kwargs2 = kwargs.copy()
if kwargs.has_key('conv'):
kwargs2['conv'] = conv = kwargs['conv'].copy()
conv = kwargs['conv']
else:
kwargs2['conv'] = conv = conversions.copy()
if kwargs.has_key('cursorclass'):
self.cursorclass = kwargs['cursorclass']
del kwargs2['cursorclass']
else:
self.cursorclass = self.default_cursor
conv = conversions
kwargs2['conv'] = dict([ (k, v) for k, v in conv.items()
if type(k) is int ])
self.cursorclass = kwargs2.pop('cursorclass', self.default_cursor)
charset = kwargs2.pop('charset', '')
charset = kwargs.get('charset', '')
if kwargs.has_key('charset'):
del kwargs2['charset']
if charset:
use_unicode = True
else:
use_unicode = False
use_unicode = kwargs.get('use_unicode', use_unicode)
if kwargs.has_key('use_unicode'):
del kwargs2['use_unicode']
sql_mode = kwargs.get('sql_mode', '')
if kwargs.has_key('sql_mode'):
del kwargs2['sql_mode']
use_unicode = kwargs2.pop('use_unicode', use_unicode)
sql_mode = kwargs2.pop('sql_mode', '')
client_flag = kwargs.get('client_flag', 0)
client_version = tuple([ int(n) for n in _mysql.get_client_info().split('.')[:2] ])
@ -164,38 +163,48 @@ class Connection(_mysql.connection):
super(Connection, self).__init__(*args, **kwargs2)
self.encoders = dict([ (k, v) for k, v in conv.items()
if type(k) is not int ])
self._server_version = tuple([ int(n) for n in self.get_server_info().split('.')[:2] ])
self.charset = self.character_set_name()
if charset:
self.set_character_set(charset)
self.charset = charset
db = proxy(self)
def _get_string_literal():
def string_literal(obj, dummy=None):
return db.string_literal(obj)
return string_literal
def _get_unicode_literal():
def unicode_literal(u, dummy=None):
return db.literal(u.encode(unicode_literal.charset))
return unicode_literal
def _get_string_decoder():
def string_decoder(s):
return s.decode(string_decoder.charset)
return string_decoder
string_literal = _get_string_literal()
self.unicode_literal = unicode_literal = _get_unicode_literal()
self.string_decoder = string_decoder = _get_string_decoder()
if not charset:
charset = self.character_set_name()
self.set_character_set(charset)
if sql_mode:
self.set_sql_mode(sql_mode)
if use_unicode:
def u(s):
# can't refer to self.character_set_name()
# because this results in reference cycles
# and memory leaks
return s.decode(charset)
conv[FIELD_TYPE.STRING].insert(-1, (None, u))
conv[FIELD_TYPE.VAR_STRING].insert(-1, (None, u))
conv[FIELD_TYPE.BLOB].insert(-1, (None, u))
self.converter[FIELD_TYPE.STRING].insert(-1, (None, string_decoder))
self.converter[FIELD_TYPE.VAR_STRING].insert(-1, (None, string_decoder))
self.converter[FIELD_TYPE.BLOB].insert(-1, (None, string_decoder))
def string_literal(obj, dummy=None):
return self.string_literal(obj)
def unicode_literal(u, dummy=None):
return self.literal(u.encode(self.charset))
self.converter[types.StringType] = string_literal
self.converter[types.UnicodeType] = unicode_literal
self.encoders[types.StringType] = string_literal
self.encoders[types.UnicodeType] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional:
# PEP-249 requires autocommit to be initially off
self.autocommit(0)
self.autocommit(False)
self.messages = []
def cursor(self, cursorclass=None):
@ -220,7 +229,7 @@ class Connection(_mysql.connection):
applications.
"""
return self.escape(o, self.converter)
return self.escape(o, self.encoders)
def begin(self):
"""Explicitly begin a connection. Non-standard.
@ -243,21 +252,22 @@ class Connection(_mysql.connection):
else:
return 0
if not hasattr(_mysql.connection, 'set_character_set'):
def set_character_set(self, charset):
"""Set the connection character set. This version
uses the SET NAMES <charset> SQL statement.
You probably shouldn't try to change character sets
after opening the connection."""
def set_character_set(self, charset):
"""Set the connection character set to charset."""
try:
super(Connection, self).set_character_set(charset)
except AttributeError:
if self._server_version < (4, 1):
raise UnsupportedError, "server is too old to set charset"
if self.charset == charset: return
self.query('SET NAMES %s' % charset)
self.store_result()
if self.character_set_name() != charset:
self.query('SET NAMES %s' % charset)
self.store_result()
self.string_decoder.charset = charset
self.unicode_literal.charset = charset
def set_sql_mode(self, sql_mode):
"""Set the connection sql_mode. See MySQL documentation for
legal values."""
if self._server_version < (4, 1):
raise UnsupportedError, "server is too old to set sql_mode"
self.query("SET SESSION sql_mode='%s'" % sql_mode)

View File

@ -36,7 +36,9 @@ class BaseCursor(object):
InternalError, ProgrammingError, NotSupportedError
def __init__(self, connection):
self.connection = connection
from weakref import proxy
self.connection = proxy(connection)
self.description = None
self.description_flags = None
self.rowcount = -1
@ -68,7 +70,7 @@ class BaseCursor(object):
def _warning_check(self):
from warnings import warn
if self._warnings:
warnings = self.connection.show_warnings()
warnings = self._get_db().show_warnings()
if warnings:
# This is done in two loops in case
# Warnings are set to raise exceptions.
@ -101,7 +103,7 @@ class BaseCursor(object):
def _post_get_result(self): pass
def _do_get_result(self):
db = self.connection
db = self._get_db()
self._result = self._get_result()
self.rowcount = db.affected_rows()
self.rownumber = 0
@ -139,9 +141,11 @@ class BaseCursor(object):
from types import ListType, TupleType
from sys import exc_info
del self.messages[:]
query = query.encode(self.connection.charset)
db = self._get_db()
charset = db.character_set_name()
query = query.encode(charset)
if args is not None:
query = query % self.connection.literal(args)
query = query % db.literal(args)
try:
r = self._query(query)
except TypeError, m:
@ -180,6 +184,7 @@ class BaseCursor(object):
"""
del self.messages[:]
db = self._get_db()
if not args: return
m = insert_values.search(query)
if not m:
@ -188,9 +193,10 @@ class BaseCursor(object):
r = r + self.execute(query, a)
return r
p = m.start(1)
query = query.encode(self.connection.charset)
charset = db.character_set_name()
query = query.encode(charset)
qv = query[p:]
qargs = self.connection.literal(args)
qargs = db.literal(args)
try:
q = [ query % qargs[0] ]
q.extend([ qv % a for a in qargs[1:] ])
@ -243,11 +249,12 @@ class BaseCursor(object):
from types import UnicodeType
db = self._get_db()
charset = db.character_set_name()
for index, arg in enumerate(args):
q = "SET @_%s_%d=%s" % (procname, index,
db.literal(arg))
if type(q) is UnicodeType:
q = q.encode(db.charset)
q = q.encode(charset)
self._query(q)
self.nextset()
@ -255,7 +262,7 @@ class BaseCursor(object):
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if type(q) is UnicodeType:
q = q.encode(db.charset)
q = q.encode(charset)
self._query(q)
self._warning_check()
return args

View File

@ -1,6 +1,6 @@
[metadata]
version: 1.2.1c7
version_info: (1,2,1,'gamma',7)
version: 1.2.1c8
version_info: (1,2,1,'gamma',8)
description: Python interface to MySQL
long_description:
=========================

View File

@ -13,21 +13,22 @@ class test_MySQLdb(test_capabilities.DatabaseTest):
connect_kwargs = dict(db='test', read_default_file='~/.my.cnf',
charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
leak_test = True
def quote_identifier(self, ident):
return "`%s`" % ident
def test_TIME(self):
from datetime import timedelta
def generator(row,col):
def generator(row,col):
return timedelta(0, row*8000)
self.check_data_integrity(
('col1 TIME',),
generator)
('col1 TIME',),
generator)
def test_TINYINT(self):
# Number data
def generator(row,col):
# Number data
def generator(row,col):
v = (row*row) % 256
if v > 127:
v = v-256
@ -76,5 +77,9 @@ class test_MySQLdb(test_capabilities.DatabaseTest):
if __name__ == '__main__':
if test_MySQLdb.leak_test:
import gc
gc.enable()
gc.set_debug(gc.DEBUG_LEAK)
unittest.main()
print '''"Huh-huh, he said 'unit'." -- Butthead'''

View File

@ -190,8 +190,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
empty = cur.fetchall()
self.assertEquals(len(empty), 0,
"non-empty result set after other result sets")
warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
Warning)
#warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
# Warning)
#assert s == None,'No more return sets, should return None'
finally:
self.help_nextset_tearDown(cur)

View File

@ -20,63 +20,78 @@ class DatabaseTest(unittest.TestCase):
debug = False
def setUp(self):
import gc
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
self.connection = db
self.cursor = db.cursor()
self.cursor = db.cursor()
self.BLOBText = ''.join([chr(i) for i in range(256)] * 100);
self.BLOBUText = u''.join([unichr(i) for i in range(16384)])
self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
leak_test = True
if leak_test:
def tearDown(self):
import gc
del self.cursor
orphans = gc.collect()
self.failIf(orphans, "%d orphaned objects found after deleting cursor" % orphans)
del self.connection
orphans = gc.collect()
self.failIf(orphans, "%d orphaned objects found after deleting connection" % orphans)
def table_exists(self, name):
try:
self.cursor.execute('select * from %s where 1=0' % name)
except:
return False
else:
return True
try:
self.cursor.execute('select * from %s where 1=0' % name)
except:
return False
else:
return True
def quote_identifier(self, ident):
return '"%s"' % ident
def new_table_name(self):
i = id(self.cursor)
while True:
name = self.quote_identifier('tb%08x' % i)
if not self.table_exists(name):
return name
i = i + 1
i = id(self.cursor)
while True:
name = self.quote_identifier('tb%08x' % i)
if not self.table_exists(name):
return name
i = i + 1
def create_table(self, columndefs):
""" Create a table using a list of column definitions given in
columndefs.
generator must be a function taking arguments (row_number,
col_number) returning a suitable data object for insertion
into the table.
""" Create a table using a list of column definitions given in
columndefs.
generator must be a function taking arguments (row_number,
col_number) returning a suitable data object for insertion
into the table.
"""
self.table = self.new_table_name()
self.cursor.execute('CREATE TABLE %s (%s) %s' %
(self.table,
"""
self.table = self.new_table_name()
self.cursor.execute('CREATE TABLE %s (%s) %s' %
(self.table,
',\n'.join(columndefs),
self.create_table_extra))
def check_data_integrity(self, columndefs, generator):
# insert
# insert
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
for i in range(self.rows) ]
if self.debug:
print data
self.cursor.executemany(insert_statement, data)
self.connection.commit()
# verify
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
# verify
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
if self.debug:
print l
self.assertEquals(len(l), self.rows)
@ -94,20 +109,20 @@ class DatabaseTest(unittest.TestCase):
if col == 0: return row
else: return ('%i' % (row%10))*255
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
for i in range(self.rows) ]
self.cursor.executemany(insert_statement, data)
# verify
# verify
self.connection.commit()
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
self.assertEquals(len(l), self.rows)
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEquals(l[i][j], generator(i,j))
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEquals(l[i][j], generator(i,j))
delete_statement = 'delete from %s where col1=%%s' % self.table
self.cursor.execute(delete_statement, (0,))
self.cursor.execute('select col1 from %s where col1=%s' % \
@ -119,7 +134,7 @@ class DatabaseTest(unittest.TestCase):
(self.table, 0))
l = self.cursor.fetchall()
self.failUnless(len(l) == 1, "ROLLBACK didn't work")
self.cursor.execute('drop table %s' % (self.table))
self.cursor.execute('drop table %s' % (self.table))
def test_truncation(self):
columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
@ -127,9 +142,9 @@ class DatabaseTest(unittest.TestCase):
if col == 0: return row
else: return ('%i' % (row%10))*((255-self.rows/2)+row)
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
try:
self.cursor.execute(insert_statement, (0, '0'*256))
@ -169,27 +184,27 @@ class DatabaseTest(unittest.TestCase):
self.fail("Over-long columns did not generate warnings/exception with executemany()")
self.connection.rollback()
self.cursor.execute('drop table %s' % (self.table))
self.cursor.execute('drop table %s' % (self.table))
def test_CHAR(self):
# Character data
def generator(row,col):
return ('%i' % ((row+col) % 10)) * 255
# Character data
def generator(row,col):
return ('%i' % ((row+col) % 10)) * 255
self.check_data_integrity(
('col1 char(255)','col2 char(255)'),
generator)
def test_INT(self):
# Number data
def generator(row,col):
# Number data
def generator(row,col):
return row*row
self.check_data_integrity(
('col1 INT',),
generator)
def test_DECIMAL(self):
# DECIMAL
def generator(row,col):
# DECIMAL
def generator(row,col):
from decimal import Decimal
return Decimal("%d.%02d" % (row, col))
self.check_data_integrity(
@ -198,78 +213,78 @@ class DatabaseTest(unittest.TestCase):
def test_DATE(self):
ticks = time()
def generator(row,col):
def generator(row,col):
return self.db_module.DateFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 DATE',),
generator)
('col1 DATE',),
generator)
def test_TIME(self):
ticks = time()
def generator(row,col):
def generator(row,col):
return self.db_module.TimeFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 TIME',),
generator)
('col1 TIME',),
generator)
def test_DATETIME(self):
ticks = time()
def generator(row,col):
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 DATETIME',),
generator)
('col1 DATETIME',),
generator)
def test_TIMESTAMP(self):
ticks = time()
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def test_fractional_TIMESTAMP(self):
ticks = time()
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def test_LONG(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(
('col1 INT','col2 LONG'),
generator)
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(
('col1 INT','col2 LONG'),
generator)
def test_TEXT(self):
def generator(row,col):
def generator(row,col):
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(
('col2 TEXT',),
generator)
('col2 TEXT',),
generator)
def test_LONG_BYTE(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(
('col1 INT','col2 LONG BYTE'),
generator)
('col1 INT','col2 LONG BYTE'),
generator)
def test_BLOB(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(
('col1 INT','col2 BLOB'),
generator)
('col1 INT','col2 BLOB'),
generator)