* Add metadata.cfg and site.cfg to MANIFEST.in so they get packaged

* Remove version_info from metadata before calling setup() to avoid
  complaints

* Fix cursor.callproc() as good as can be fixed.

* Improve/fix various tests for stored procedures.
This commit is contained in:
adustman
2006-03-02 04:19:32 +00:00
parent 6024695c34
commit c93adbedd4
7 changed files with 385 additions and 12 deletions

View File

@ -7,3 +7,7 @@ include GPL
include pymemcompat.h
include dbapi20.py
include test_MySQLdb_dbapi20.py
include test_capabilities.py
include test_MySQLdb_capabilities.py
include metadata.cfg
include site.cfg

View File

@ -199,7 +199,7 @@ class BaseCursor(object):
self._warning_check()
return r
def callproc(self, procname, args):
def callproc(self, procname, args=()):
"""Execute stored procedure procname with args
@ -220,6 +220,13 @@ class BaseCursor(object):
(from zero). Once all result sets generated by the procedure
have been fetched, you can issue a SELECT @_procname_0, ...
query using .execute() to get any OUT or INOUT values.
Compatibility warning: The act of calling a stored procedure
itself creates an empty result set. This appears after any
result sets generated by the procedure. This is non-standard
behavior with respect to the DB-API. Be sure to use nextset()
to advance through all result sets; otherwise you may get
disconnected.
"""
from types import UnicodeType
@ -229,16 +236,16 @@ class BaseCursor(object):
db.literal(arg))
if type(q) is UnicodeType:
q = q.encode(db.charset)
db.query(q)
self._do_get_result()
self._query(q)
self.nextset()
q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if type(q) is UnicodeType:
q = q.encode(db.charset)
db.query(q)
self._do_get_result()
self._query(q)
self._warning_check()
return args
def _do_query(self, q):

View File

@ -1,6 +1,6 @@
[metadata]
version: 1.2.1c4
version_info: (1,2,1,'gamma',4)
version: 1.2.1c5
version_info: (1,2,1,'gamma',5)
description: Python interface to MySQL
long_description:
=========================

View File

@ -86,6 +86,8 @@ __version__ = "%(version)s"
""" % metadata)
rel.close()
del metadata['version_info']
ext_mysql_metadata = dict(
name="_mysql",
include_dirs=include_dirs,

View File

@ -0,0 +1,82 @@
#!/usr/bin/env python
import test_capabilities
import unittest
import MySQLdb
import warnings
warnings.filterwarnings('error')
class test_MySQLdb(test_capabilities.DatabaseTest):
db_module = MySQLdb
connect_args = ()
connect_kwargs = dict(db='test', read_default_file='~/.my.cnf',
use_unicode=True)
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
def quote_identifier(self, ident):
return "`%s`" % ident
def test_TIME(self):
from datetime import timedelta
def generator(row,col):
return timedelta(0, row*8000)
self.check_data_integrity(
('col1 TIME',),
generator)
def test_TINYINT(self):
# Number data
def generator(row,col):
v = (row*row) % 256
if v > 127:
v = v-256
return v
self.check_data_integrity(
('col1 TINYINT',),
generator)
def test_SET(self):
if True: return
things = 'ash birch cedar larch pine'.split()
def generator(row, col):
from sets import Set
s = Set()
for i in range(len(things)):
if (row >> i) & 1:
s.add(things[i])
return s
self.debug = True
self.check_data_integrity(
('col1 SET(%s)' % ','.join(["'%s'" % t for t in things]),),
generator)
def test_stored_procedures(self):
db = self.connection
c = self.cursor
self.create_table(('pos INT', 'tree CHAR(20)'))
c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
list(enumerate('ash birch cedar larch pine'.split())))
db.commit()
c.execute("""
CREATE PROCEDURE test_sp(IN t VARCHAR(255))
BEGIN
SELECT pos FROM %s WHERE tree = t;
END
""" % self.table)
db.commit()
c.callproc('test_sp', ('larch',))
rows = c.fetchall()
self.assertEquals(len(rows), 1)
self.assertEquals(rows[0][0], 3)
c.nextset()
c.execute("DROP PROCEDURE test_sp")
c.execute('drop table %s' % (self.table))
if __name__ == '__main__':
unittest.main()
print '''"Huh-huh, he said 'unit'." -- Butthead'''

View File

@ -140,11 +140,62 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
con.close()
def test_callproc(self):
pass # performed in test_MySQL_capabilities
def help_nextset_setUp(self,cur):
''' Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
'''
sql="""
create procedure deleteme()
begin
select count(*) from %(tp)sbooze;
select name from %(tp)sbooze;
end
""" % dict(tp=self.table_prefix)
cur.execute(sql)
def help_nextset_tearDown(self,cur):
'If cleaning up is needed after nextSetTest'
cur.execute("drop procedure deleteme")
def test_nextset(self):
from warnings import warn
con = self._connect()
try:
dbapi20.DatabaseAPI20Test.test_callproc(self)
except MySQLdb.ProgrammingError:
# not supported by server
pass
cur = con.cursor()
if not hasattr(cur,'nextset'):
return
try:
self.executeDDL1(cur)
sql=self._populate()
for sql in self._populate():
cur.execute(sql)
self.help_nextset_setUp(cur)
cur.callproc('deleteme')
numberofrows=cur.fetchone()
assert numberofrows[0]== len(self.samples)
assert cur.nextset()
names=cur.fetchall()
assert len(names) == len(self.samples)
s=cur.nextset()
if s:
empty = cur.fetchall()
self.assertEquals(len(empty), 0,
"non-empty result set after other result sets")
warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
Warning)
#assert s == None,'No more return sets, should return None'
finally:
self.help_nextset_tearDown(cur)
finally:
con.close()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,227 @@
#!/usr/bin/env python -O
""" Script to test database capabilities and the DB-API interface
for functionality and memory leaks.
Adapted from a script by M-A Lemburg.
"""
from time import time
import array
import unittest
class DatabaseTest(unittest.TestCase):
db_module = None
connect_args = ()
connect_kwargs = dict()
create_table_extra = ''
rows = 10
debug = False
def setUp(self):
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
self.connection = db
self.cursor = db.cursor()
self.BLOBText = ''.join([chr(i) for i in range(256)] * 100);
self.BLOBUText = u''.join([unichr(i) for i in range(16384)])
self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
def table_exists(self, name):
try:
self.cursor.execute('select * from %s where 1=0' % name)
except:
return False
else:
return True
def quote_identifier(self, ident):
return '"%s"' % ident
def new_table_name(self):
i = id(self.cursor)
while True:
name = self.quote_identifier('tb%08x' % i)
if not self.table_exists(name):
return name
i = i + 1
def create_table(self, columndefs):
""" Create a table using a list of column definitions given in
columndefs.
generator must be a function taking arguments (row_number,
col_number) returning a suitable data object for insertion
into the table.
"""
self.table = self.new_table_name()
self.cursor.execute('CREATE TABLE %s (%s) %s' %
(self.table,
',\n'.join(columndefs),
self.create_table_extra))
def check_data_integrity(self, columndefs, generator):
# insert
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
for i in range(self.rows) ]
if self.debug:
print data
self.cursor.executemany(insert_statement, data)
self.connection.commit()
# verify
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
if self.debug:
print l
self.assertEquals(len(l), self.rows)
try:
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEquals(l[i][j], generator(i,j))
finally:
if not self.debug:
self.cursor.execute('drop table %s' % (self.table))
def test_transactions(self):
columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
def generator(row, col):
if col == 0: return row
else: return ('%i' % (row%10))*255
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
for i in range(self.rows):
data = []
for j in range(len(columndefs)):
data.append(generator(i,j))
self.cursor.execute(insert_statement,tuple(data))
# verify
self.connection.commit()
self.cursor.execute('select * from %s' % self.table)
l = self.cursor.fetchall()
self.assertEquals(len(l), self.rows)
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEquals(l[i][j], generator(i,j))
delete_statement = 'delete from %s where col1=%%s' % self.table
self.cursor.execute(delete_statement, (0,))
self.cursor.execute('select col1 from %s where col1=%s' % \
(self.table, 0))
l = self.cursor.fetchall()
self.failIf(l, "DELETE didn't work")
self.connection.rollback()
self.cursor.execute('select col1 from %s where col1=%s' % \
(self.table, 0))
l = self.cursor.fetchall()
self.failUnless(len(l) == 1, "ROLLBACK didn't work")
self.cursor.execute('drop table %s' % (self.table))
def test_CHAR(self):
# Character data
def generator(row,col):
return ('%i' % ((row+col) % 10)) * 255
self.check_data_integrity(
('col1 char(255)','col2 char(255)'),
generator)
def test_INT(self):
# Number data
def generator(row,col):
return row*row
self.check_data_integrity(
('col1 INT',),
generator)
def test_DECIMAL(self):
# DECIMAL
def generator(row,col):
from decimal import Decimal
return Decimal("%d.%02d" % (row, col))
self.check_data_integrity(
('col1 DECIMAL(5,2)',),
generator)
def test_DATE(self):
ticks = time()
def generator(row,col):
return self.db_module.DateFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 DATE',),
generator)
def test_TIME(self):
ticks = time()
def generator(row,col):
return self.db_module.TimeFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 TIME',),
generator)
def test_DATETIME(self):
ticks = time()
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 DATETIME',),
generator)
def test_TIMESTAMP(self):
ticks = time()
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def test_fractional_TIMESTAMP(self):
ticks = time()
def generator(row,col):
return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
self.check_data_integrity(
('col1 TIMESTAMP',),
generator)
def test_LONG(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(
('col1 INT','col2 LONG'),
generator)
def test_TEXT(self):
def generator(row,col):
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(
('col2 TEXT',),
generator)
def test_LONG_BYTE(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(
('col1 INT','col2 LONG BYTE'),
generator)
def test_BLOB(self):
def generator(row,col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(
('col1 INT','col2 BLOB'),
generator)