diff --git a/MySQLdb/MySQLdb/connections.py b/MySQLdb/MySQLdb/connections.py index dbaf51f..4482a34 100644 --- a/MySQLdb/MySQLdb/connections.py +++ b/MySQLdb/MySQLdb/connections.py @@ -30,6 +30,8 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue): cursor.messages.append(error) else: connection.messages.append(error) + del cursor + del connection raise errorclass, errorvalue @@ -125,33 +127,30 @@ class Connection(_mysql.connection): """ from constants import CLIENT, FIELD_TYPE from converters import conversions + from weakref import proxy, WeakValueDictionary + import types kwargs2 = kwargs.copy() + if kwargs.has_key('conv'): - kwargs2['conv'] = conv = kwargs['conv'].copy() + conv = kwargs['conv'] else: - kwargs2['conv'] = conv = conversions.copy() - if kwargs.has_key('cursorclass'): - self.cursorclass = kwargs['cursorclass'] - del kwargs2['cursorclass'] - else: - self.cursorclass = self.default_cursor + conv = conversions + + kwargs2['conv'] = dict([ (k, v) for k, v in conv.items() + if type(k) is int ]) + + self.cursorclass = kwargs2.pop('cursorclass', self.default_cursor) + charset = kwargs2.pop('charset', '') - charset = kwargs.get('charset', '') - if kwargs.has_key('charset'): - del kwargs2['charset'] if charset: use_unicode = True else: use_unicode = False - use_unicode = kwargs.get('use_unicode', use_unicode) - if kwargs.has_key('use_unicode'): - del kwargs2['use_unicode'] - sql_mode = kwargs.get('sql_mode', '') - if kwargs.has_key('sql_mode'): - del kwargs2['sql_mode'] + use_unicode = kwargs2.pop('use_unicode', use_unicode) + sql_mode = kwargs2.pop('sql_mode', '') client_flag = kwargs.get('client_flag', 0) client_version = tuple([ int(n) for n in _mysql.get_client_info().split('.')[:2] ]) @@ -164,38 +163,48 @@ class Connection(_mysql.connection): super(Connection, self).__init__(*args, **kwargs2) + self.encoders = dict([ (k, v) for k, v in conv.items() + if type(k) is not int ]) + self._server_version = tuple([ int(n) for n in self.get_server_info().split('.')[:2] ]) - self.charset = self.character_set_name() - if charset: - self.set_character_set(charset) - self.charset = charset + db = proxy(self) + def _get_string_literal(): + def string_literal(obj, dummy=None): + return db.string_literal(obj) + return string_literal + + def _get_unicode_literal(): + def unicode_literal(u, dummy=None): + return db.literal(u.encode(unicode_literal.charset)) + return unicode_literal + + def _get_string_decoder(): + def string_decoder(s): + return s.decode(string_decoder.charset) + return string_decoder + + string_literal = _get_string_literal() + self.unicode_literal = unicode_literal = _get_unicode_literal() + self.string_decoder = string_decoder = _get_string_decoder() + if not charset: + charset = self.character_set_name() + self.set_character_set(charset) if sql_mode: self.set_sql_mode(sql_mode) if use_unicode: - def u(s): - # can't refer to self.character_set_name() - # because this results in reference cycles - # and memory leaks - return s.decode(charset) - conv[FIELD_TYPE.STRING].insert(-1, (None, u)) - conv[FIELD_TYPE.VAR_STRING].insert(-1, (None, u)) - conv[FIELD_TYPE.BLOB].insert(-1, (None, u)) + self.converter[FIELD_TYPE.STRING].insert(-1, (None, string_decoder)) + self.converter[FIELD_TYPE.VAR_STRING].insert(-1, (None, string_decoder)) + self.converter[FIELD_TYPE.BLOB].insert(-1, (None, string_decoder)) - def string_literal(obj, dummy=None): - return self.string_literal(obj) - - def unicode_literal(u, dummy=None): - return self.literal(u.encode(self.charset)) - - self.converter[types.StringType] = string_literal - self.converter[types.UnicodeType] = unicode_literal + self.encoders[types.StringType] = string_literal + self.encoders[types.UnicodeType] = unicode_literal self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS if self._transactional: # PEP-249 requires autocommit to be initially off - self.autocommit(0) + self.autocommit(False) self.messages = [] def cursor(self, cursorclass=None): @@ -220,7 +229,7 @@ class Connection(_mysql.connection): applications. """ - return self.escape(o, self.converter) + return self.escape(o, self.encoders) def begin(self): """Explicitly begin a connection. Non-standard. @@ -243,21 +252,22 @@ class Connection(_mysql.connection): else: return 0 - if not hasattr(_mysql.connection, 'set_character_set'): - - def set_character_set(self, charset): - """Set the connection character set. This version - uses the SET NAMES SQL statement. - - You probably shouldn't try to change character sets - after opening the connection.""" + def set_character_set(self, charset): + """Set the connection character set to charset.""" + try: + super(Connection, self).set_character_set(charset) + except AttributeError: if self._server_version < (4, 1): raise UnsupportedError, "server is too old to set charset" - if self.charset == charset: return - self.query('SET NAMES %s' % charset) - self.store_result() + if self.character_set_name() != charset: + self.query('SET NAMES %s' % charset) + self.store_result() + self.string_decoder.charset = charset + self.unicode_literal.charset = charset def set_sql_mode(self, sql_mode): + """Set the connection sql_mode. See MySQL documentation for + legal values.""" if self._server_version < (4, 1): raise UnsupportedError, "server is too old to set sql_mode" self.query("SET SESSION sql_mode='%s'" % sql_mode) diff --git a/MySQLdb/MySQLdb/cursors.py b/MySQLdb/MySQLdb/cursors.py index 9d8433e..0ee0955 100644 --- a/MySQLdb/MySQLdb/cursors.py +++ b/MySQLdb/MySQLdb/cursors.py @@ -36,7 +36,9 @@ class BaseCursor(object): InternalError, ProgrammingError, NotSupportedError def __init__(self, connection): - self.connection = connection + from weakref import proxy + + self.connection = proxy(connection) self.description = None self.description_flags = None self.rowcount = -1 @@ -68,7 +70,7 @@ class BaseCursor(object): def _warning_check(self): from warnings import warn if self._warnings: - warnings = self.connection.show_warnings() + warnings = self._get_db().show_warnings() if warnings: # This is done in two loops in case # Warnings are set to raise exceptions. @@ -101,7 +103,7 @@ class BaseCursor(object): def _post_get_result(self): pass def _do_get_result(self): - db = self.connection + db = self._get_db() self._result = self._get_result() self.rowcount = db.affected_rows() self.rownumber = 0 @@ -139,9 +141,11 @@ class BaseCursor(object): from types import ListType, TupleType from sys import exc_info del self.messages[:] - query = query.encode(self.connection.charset) + db = self._get_db() + charset = db.character_set_name() + query = query.encode(charset) if args is not None: - query = query % self.connection.literal(args) + query = query % db.literal(args) try: r = self._query(query) except TypeError, m: @@ -180,6 +184,7 @@ class BaseCursor(object): """ del self.messages[:] + db = self._get_db() if not args: return m = insert_values.search(query) if not m: @@ -188,9 +193,10 @@ class BaseCursor(object): r = r + self.execute(query, a) return r p = m.start(1) - query = query.encode(self.connection.charset) + charset = db.character_set_name() + query = query.encode(charset) qv = query[p:] - qargs = self.connection.literal(args) + qargs = db.literal(args) try: q = [ query % qargs[0] ] q.extend([ qv % a for a in qargs[1:] ]) @@ -243,11 +249,12 @@ class BaseCursor(object): from types import UnicodeType db = self._get_db() + charset = db.character_set_name() for index, arg in enumerate(args): q = "SET @_%s_%d=%s" % (procname, index, db.literal(arg)) if type(q) is UnicodeType: - q = q.encode(db.charset) + q = q.encode(charset) self._query(q) self.nextset() @@ -255,7 +262,7 @@ class BaseCursor(object): ','.join(['@_%s_%d' % (procname, i) for i in range(len(args))])) if type(q) is UnicodeType: - q = q.encode(db.charset) + q = q.encode(charset) self._query(q) self._warning_check() return args diff --git a/MySQLdb/metadata.cfg b/MySQLdb/metadata.cfg index ef2a13a..047d196 100644 --- a/MySQLdb/metadata.cfg +++ b/MySQLdb/metadata.cfg @@ -1,6 +1,6 @@ [metadata] -version: 1.2.1c7 -version_info: (1,2,1,'gamma',7) +version: 1.2.1c8 +version_info: (1,2,1,'gamma',8) description: Python interface to MySQL long_description: ========================= diff --git a/MySQLdb/test_MySQLdb_capabilities.py b/MySQLdb/test_MySQLdb_capabilities.py index 281c259..588b308 100644 --- a/MySQLdb/test_MySQLdb_capabilities.py +++ b/MySQLdb/test_MySQLdb_capabilities.py @@ -13,21 +13,22 @@ class test_MySQLdb(test_capabilities.DatabaseTest): connect_kwargs = dict(db='test', read_default_file='~/.my.cnf', charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" - + leak_test = True + def quote_identifier(self, ident): return "`%s`" % ident def test_TIME(self): from datetime import timedelta - def generator(row,col): + def generator(row,col): return timedelta(0, row*8000) self.check_data_integrity( - ('col1 TIME',), - generator) + ('col1 TIME',), + generator) def test_TINYINT(self): - # Number data - def generator(row,col): + # Number data + def generator(row,col): v = (row*row) % 256 if v > 127: v = v-256 @@ -76,5 +77,9 @@ class test_MySQLdb(test_capabilities.DatabaseTest): if __name__ == '__main__': + if test_MySQLdb.leak_test: + import gc + gc.enable() + gc.set_debug(gc.DEBUG_LEAK) unittest.main() print '''"Huh-huh, he said 'unit'." -- Butthead''' diff --git a/MySQLdb/test_MySQLdb_dbapi20.py b/MySQLdb/test_MySQLdb_dbapi20.py index f866088..a466611 100644 --- a/MySQLdb/test_MySQLdb_dbapi20.py +++ b/MySQLdb/test_MySQLdb_dbapi20.py @@ -190,8 +190,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): empty = cur.fetchall() self.assertEquals(len(empty), 0, "non-empty result set after other result sets") - warn("Incompatibility: MySQL returns an empty result set for the CALL itself", - Warning) + #warn("Incompatibility: MySQL returns an empty result set for the CALL itself", + # Warning) #assert s == None,'No more return sets, should return None' finally: self.help_nextset_tearDown(cur) diff --git a/MySQLdb/test_capabilities.py b/MySQLdb/test_capabilities.py index 3edfb02..40b6874 100644 --- a/MySQLdb/test_capabilities.py +++ b/MySQLdb/test_capabilities.py @@ -20,63 +20,78 @@ class DatabaseTest(unittest.TestCase): debug = False def setUp(self): + import gc db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) self.connection = db - self.cursor = db.cursor() + self.cursor = db.cursor() self.BLOBText = ''.join([chr(i) for i in range(256)] * 100); self.BLOBUText = u''.join([unichr(i) for i in range(16384)]) self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16)) + leak_test = True + + if leak_test: + + def tearDown(self): + import gc + del self.cursor + orphans = gc.collect() + self.failIf(orphans, "%d orphaned objects found after deleting cursor" % orphans) + + del self.connection + orphans = gc.collect() + self.failIf(orphans, "%d orphaned objects found after deleting connection" % orphans) + def table_exists(self, name): - try: - self.cursor.execute('select * from %s where 1=0' % name) - except: - return False - else: - return True + try: + self.cursor.execute('select * from %s where 1=0' % name) + except: + return False + else: + return True def quote_identifier(self, ident): return '"%s"' % ident def new_table_name(self): - i = id(self.cursor) - while True: - name = self.quote_identifier('tb%08x' % i) - if not self.table_exists(name): - return name - i = i + 1 + i = id(self.cursor) + while True: + name = self.quote_identifier('tb%08x' % i) + if not self.table_exists(name): + return name + i = i + 1 def create_table(self, columndefs): - """ Create a table using a list of column definitions given in - columndefs. - - generator must be a function taking arguments (row_number, - col_number) returning a suitable data object for insertion - into the table. + """ Create a table using a list of column definitions given in + columndefs. + + generator must be a function taking arguments (row_number, + col_number) returning a suitable data object for insertion + into the table. - """ - self.table = self.new_table_name() - self.cursor.execute('CREATE TABLE %s (%s) %s' % - (self.table, + """ + self.table = self.new_table_name() + self.cursor.execute('CREATE TABLE %s (%s) %s' % + (self.table, ',\n'.join(columndefs), self.create_table_extra)) def check_data_integrity(self, columndefs, generator): - # insert + # insert self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) data = [ [ generator(i,j) for j in range(len(columndefs)) ] for i in range(self.rows) ] if self.debug: print data self.cursor.executemany(insert_statement, data) self.connection.commit() - # verify - self.cursor.execute('select * from %s' % self.table) - l = self.cursor.fetchall() + # verify + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() if self.debug: print l self.assertEquals(len(l), self.rows) @@ -94,20 +109,20 @@ class DatabaseTest(unittest.TestCase): if col == 0: return row else: return ('%i' % (row%10))*255 self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) data = [ [ generator(i,j) for j in range(len(columndefs)) ] for i in range(self.rows) ] self.cursor.executemany(insert_statement, data) - # verify + # verify self.connection.commit() - self.cursor.execute('select * from %s' % self.table) - l = self.cursor.fetchall() + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() self.assertEquals(len(l), self.rows) - for i in range(self.rows): - for j in range(len(columndefs)): - self.assertEquals(l[i][j], generator(i,j)) + for i in range(self.rows): + for j in range(len(columndefs)): + self.assertEquals(l[i][j], generator(i,j)) delete_statement = 'delete from %s where col1=%%s' % self.table self.cursor.execute(delete_statement, (0,)) self.cursor.execute('select col1 from %s where col1=%s' % \ @@ -119,7 +134,7 @@ class DatabaseTest(unittest.TestCase): (self.table, 0)) l = self.cursor.fetchall() self.failUnless(len(l) == 1, "ROLLBACK didn't work") - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute('drop table %s' % (self.table)) def test_truncation(self): columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') @@ -127,9 +142,9 @@ class DatabaseTest(unittest.TestCase): if col == 0: return row else: return ('%i' % (row%10))*((255-self.rows/2)+row) self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) try: self.cursor.execute(insert_statement, (0, '0'*256)) @@ -169,27 +184,27 @@ class DatabaseTest(unittest.TestCase): self.fail("Over-long columns did not generate warnings/exception with executemany()") self.connection.rollback() - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute('drop table %s' % (self.table)) def test_CHAR(self): - # Character data - def generator(row,col): - return ('%i' % ((row+col) % 10)) * 255 + # Character data + def generator(row,col): + return ('%i' % ((row+col) % 10)) * 255 self.check_data_integrity( ('col1 char(255)','col2 char(255)'), generator) def test_INT(self): - # Number data - def generator(row,col): + # Number data + def generator(row,col): return row*row self.check_data_integrity( ('col1 INT',), generator) def test_DECIMAL(self): - # DECIMAL - def generator(row,col): + # DECIMAL + def generator(row,col): from decimal import Decimal return Decimal("%d.%02d" % (row, col)) self.check_data_integrity( @@ -198,78 +213,78 @@ class DatabaseTest(unittest.TestCase): def test_DATE(self): ticks = time() - def generator(row,col): + def generator(row,col): return self.db_module.DateFromTicks(ticks+row*86400-col*1313) self.check_data_integrity( - ('col1 DATE',), - generator) + ('col1 DATE',), + generator) def test_TIME(self): ticks = time() - def generator(row,col): + def generator(row,col): return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) self.check_data_integrity( - ('col1 TIME',), - generator) + ('col1 TIME',), + generator) def test_DATETIME(self): ticks = time() - def generator(row,col): + def generator(row,col): return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) self.check_data_integrity( - ('col1 DATETIME',), - generator) + ('col1 DATETIME',), + generator) def test_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) def test_fractional_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) def test_LONG(self): - def generator(row,col): - if col == 0: - return row - else: - return self.BLOBUText # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 LONG'), - generator) + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBUText # 'BLOB Text ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 LONG'), + generator) def test_TEXT(self): - def generator(row,col): + def generator(row,col): return self.BLOBUText # 'BLOB Text ' * 1024 self.check_data_integrity( - ('col2 TEXT',), - generator) + ('col2 TEXT',), + generator) def test_LONG_BYTE(self): - def generator(row,col): - if col == 0: - return row - else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 self.check_data_integrity( - ('col1 INT','col2 LONG BYTE'), - generator) + ('col1 INT','col2 LONG BYTE'), + generator) def test_BLOB(self): - def generator(row,col): - if col == 0: - return row - else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 self.check_data_integrity( - ('col1 INT','col2 BLOB'), - generator) + ('col1 INT','col2 BLOB'), + generator)