From c93adbedd4ae5f6761231a4c2590d46c78c18a4c Mon Sep 17 00:00:00 2001 From: adustman Date: Thu, 2 Mar 2006 04:19:32 +0000 Subject: [PATCH] * Add metadata.cfg and site.cfg to MANIFEST.in so they get packaged * Remove version_info from metadata before calling setup() to avoid complaints * Fix cursor.callproc() as good as can be fixed. * Improve/fix various tests for stored procedures. --- MySQLdb/MANIFEST.in | 4 + MySQLdb/MySQLdb/cursors.py | 19 ++- MySQLdb/metadata.cfg | 4 +- MySQLdb/setup.py | 2 + MySQLdb/test_MySQLdb_capabilities.py | 82 ++++++++++ MySQLdb/test_MySQLdb_dbapi20.py | 59 ++++++- MySQLdb/test_capabilities.py | 227 +++++++++++++++++++++++++++ 7 files changed, 385 insertions(+), 12 deletions(-) create mode 100644 MySQLdb/test_MySQLdb_capabilities.py create mode 100644 MySQLdb/test_capabilities.py diff --git a/MySQLdb/MANIFEST.in b/MySQLdb/MANIFEST.in index 2f1edda..a146fe4 100644 --- a/MySQLdb/MANIFEST.in +++ b/MySQLdb/MANIFEST.in @@ -7,3 +7,7 @@ include GPL include pymemcompat.h include dbapi20.py include test_MySQLdb_dbapi20.py +include test_capabilities.py +include test_MySQLdb_capabilities.py +include metadata.cfg +include site.cfg diff --git a/MySQLdb/MySQLdb/cursors.py b/MySQLdb/MySQLdb/cursors.py index 8912a28..e0c37f7 100644 --- a/MySQLdb/MySQLdb/cursors.py +++ b/MySQLdb/MySQLdb/cursors.py @@ -199,7 +199,7 @@ class BaseCursor(object): self._warning_check() return r - def callproc(self, procname, args): + def callproc(self, procname, args=()): """Execute stored procedure procname with args @@ -220,6 +220,13 @@ class BaseCursor(object): (from zero). Once all result sets generated by the procedure have been fetched, you can issue a SELECT @_procname_0, ... query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. """ from types import UnicodeType @@ -229,16 +236,16 @@ class BaseCursor(object): db.literal(arg)) if type(q) is UnicodeType: q = q.encode(db.charset) - db.query(q) - self._do_get_result() - + self._query(q) + self.nextset() + q = "CALL %s(%s)" % (procname, ','.join(['@_%s_%d' % (procname, i) for i in range(len(args))])) if type(q) is UnicodeType: q = q.encode(db.charset) - db.query(q) - self._do_get_result() + self._query(q) + self._warning_check() return args def _do_query(self, q): diff --git a/MySQLdb/metadata.cfg b/MySQLdb/metadata.cfg index a6c3137..b139101 100644 --- a/MySQLdb/metadata.cfg +++ b/MySQLdb/metadata.cfg @@ -1,6 +1,6 @@ [metadata] -version: 1.2.1c4 -version_info: (1,2,1,'gamma',4) +version: 1.2.1c5 +version_info: (1,2,1,'gamma',5) description: Python interface to MySQL long_description: ========================= diff --git a/MySQLdb/setup.py b/MySQLdb/setup.py index a6500c3..64cfaae 100644 --- a/MySQLdb/setup.py +++ b/MySQLdb/setup.py @@ -86,6 +86,8 @@ __version__ = "%(version)s" """ % metadata) rel.close() +del metadata['version_info'] + ext_mysql_metadata = dict( name="_mysql", include_dirs=include_dirs, diff --git a/MySQLdb/test_MySQLdb_capabilities.py b/MySQLdb/test_MySQLdb_capabilities.py new file mode 100644 index 0000000..bcda1b7 --- /dev/null +++ b/MySQLdb/test_MySQLdb_capabilities.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +import test_capabilities +import unittest +import MySQLdb +import warnings + +warnings.filterwarnings('error') + +class test_MySQLdb(test_capabilities.DatabaseTest): + + db_module = MySQLdb + connect_args = () + connect_kwargs = dict(db='test', read_default_file='~/.my.cnf', + use_unicode=True) + create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" + + def quote_identifier(self, ident): + return "`%s`" % ident + + def test_TIME(self): + from datetime import timedelta + def generator(row,col): + return timedelta(0, row*8000) + self.check_data_integrity( + ('col1 TIME',), + generator) + + def test_TINYINT(self): + # Number data + def generator(row,col): + v = (row*row) % 256 + if v > 127: + v = v-256 + return v + self.check_data_integrity( + ('col1 TINYINT',), + generator) + + def test_SET(self): + if True: return + things = 'ash birch cedar larch pine'.split() + def generator(row, col): + from sets import Set + s = Set() + for i in range(len(things)): + if (row >> i) & 1: + s.add(things[i]) + return s + self.debug = True + self.check_data_integrity( + ('col1 SET(%s)' % ','.join(["'%s'" % t for t in things]),), + generator) + + def test_stored_procedures(self): + db = self.connection + c = self.cursor + self.create_table(('pos INT', 'tree CHAR(20)')) + c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, + list(enumerate('ash birch cedar larch pine'.split()))) + db.commit() + + c.execute(""" + CREATE PROCEDURE test_sp(IN t VARCHAR(255)) + BEGIN + SELECT pos FROM %s WHERE tree = t; + END + """ % self.table) + db.commit() + + c.callproc('test_sp', ('larch',)) + rows = c.fetchall() + self.assertEquals(len(rows), 1) + self.assertEquals(rows[0][0], 3) + c.nextset() + + c.execute("DROP PROCEDURE test_sp") + c.execute('drop table %s' % (self.table)) + + +if __name__ == '__main__': + unittest.main() + print '''"Huh-huh, he said 'unit'." -- Butthead''' diff --git a/MySQLdb/test_MySQLdb_dbapi20.py b/MySQLdb/test_MySQLdb_dbapi20.py index 81aff7d..7c4bd15 100644 --- a/MySQLdb/test_MySQLdb_dbapi20.py +++ b/MySQLdb/test_MySQLdb_dbapi20.py @@ -140,11 +140,62 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): con.close() def test_callproc(self): + pass # performed in test_MySQL_capabilities + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + sql=""" + create procedure deleteme() + begin + select count(*) from %(tp)sbooze; + select name from %(tp)sbooze; + end + """ % dict(tp=self.table_prefix) + cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + cur.execute("drop procedure deleteme") + + def test_nextset(self): + from warnings import warn + con = self._connect() try: - dbapi20.DatabaseAPI20Test.test_callproc(self) - except MySQLdb.ProgrammingError: - # not supported by server - pass + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + if s: + empty = cur.fetchall() + self.assertEquals(len(empty), 0, + "non-empty result set after other result sets") + warn("Incompatibility: MySQL returns an empty result set for the CALL itself", + Warning) + #assert s == None,'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + if __name__ == '__main__': unittest.main() diff --git a/MySQLdb/test_capabilities.py b/MySQLdb/test_capabilities.py new file mode 100644 index 0000000..6b9bf35 --- /dev/null +++ b/MySQLdb/test_capabilities.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python -O +""" Script to test database capabilities and the DB-API interface + for functionality and memory leaks. + + Adapted from a script by M-A Lemburg. + +""" +from time import time +import array +import unittest + + +class DatabaseTest(unittest.TestCase): + + db_module = None + connect_args = () + connect_kwargs = dict() + create_table_extra = '' + rows = 10 + debug = False + + def setUp(self): + db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) + self.connection = db + self.cursor = db.cursor() + self.BLOBText = ''.join([chr(i) for i in range(256)] * 100); + self.BLOBUText = u''.join([unichr(i) for i in range(16384)]) + self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16)) + + def table_exists(self, name): + try: + self.cursor.execute('select * from %s where 1=0' % name) + except: + return False + else: + return True + + def quote_identifier(self, ident): + return '"%s"' % ident + + def new_table_name(self): + i = id(self.cursor) + while True: + name = self.quote_identifier('tb%08x' % i) + if not self.table_exists(name): + return name + i = i + 1 + + def create_table(self, columndefs): + + """ Create a table using a list of column definitions given in + columndefs. + + generator must be a function taking arguments (row_number, + col_number) returning a suitable data object for insertion + into the table. + + """ + self.table = self.new_table_name() + self.cursor.execute('CREATE TABLE %s (%s) %s' % + (self.table, + ',\n'.join(columndefs), + self.create_table_extra)) + + def check_data_integrity(self, columndefs, generator): + # insert + self.create_table(columndefs) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) + data = [ [ generator(i,j) for j in range(len(columndefs)) ] + for i in range(self.rows) ] + if self.debug: + print data + self.cursor.executemany(insert_statement, data) + self.connection.commit() + # verify + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() + if self.debug: + print l + self.assertEquals(len(l), self.rows) + try: + for i in range(self.rows): + for j in range(len(columndefs)): + self.assertEquals(l[i][j], generator(i,j)) + finally: + if not self.debug: + self.cursor.execute('drop table %s' % (self.table)) + + def test_transactions(self): + columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + def generator(row, col): + if col == 0: return row + else: return ('%i' % (row%10))*255 + self.create_table(columndefs) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) + for i in range(self.rows): + data = [] + for j in range(len(columndefs)): + data.append(generator(i,j)) + self.cursor.execute(insert_statement,tuple(data)) + # verify + self.connection.commit() + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() + self.assertEquals(len(l), self.rows) + for i in range(self.rows): + for j in range(len(columndefs)): + self.assertEquals(l[i][j], generator(i,j)) + delete_statement = 'delete from %s where col1=%%s' % self.table + self.cursor.execute(delete_statement, (0,)) + self.cursor.execute('select col1 from %s where col1=%s' % \ + (self.table, 0)) + l = self.cursor.fetchall() + self.failIf(l, "DELETE didn't work") + self.connection.rollback() + self.cursor.execute('select col1 from %s where col1=%s' % \ + (self.table, 0)) + l = self.cursor.fetchall() + self.failUnless(len(l) == 1, "ROLLBACK didn't work") + self.cursor.execute('drop table %s' % (self.table)) + + def test_CHAR(self): + # Character data + def generator(row,col): + return ('%i' % ((row+col) % 10)) * 255 + self.check_data_integrity( + ('col1 char(255)','col2 char(255)'), + generator) + + def test_INT(self): + # Number data + def generator(row,col): + return row*row + self.check_data_integrity( + ('col1 INT',), + generator) + + def test_DECIMAL(self): + # DECIMAL + def generator(row,col): + from decimal import Decimal + return Decimal("%d.%02d" % (row, col)) + self.check_data_integrity( + ('col1 DECIMAL(5,2)',), + generator) + + def test_DATE(self): + ticks = time() + def generator(row,col): + return self.db_module.DateFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 DATE',), + generator) + + def test_TIME(self): + ticks = time() + def generator(row,col): + return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 TIME',), + generator) + + def test_DATETIME(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 DATETIME',), + generator) + + def test_TIMESTAMP(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) + + def test_fractional_TIMESTAMP(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) + + def test_LONG(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBUText # 'BLOB Text ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 LONG'), + generator) + + def test_TEXT(self): + def generator(row,col): + return self.BLOBUText # 'BLOB Text ' * 1024 + self.check_data_integrity( + ('col2 TEXT',), + generator) + + def test_LONG_BYTE(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 LONG BYTE'), + generator) + + def test_BLOB(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 BLOB'), + generator) +