mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 11:10:58 +08:00
318 lines
9.8 KiB
Python
318 lines
9.8 KiB
Python
#!/usr/bin/env python -O
|
|
""" Script to test database capabilities and the DB-API interface
|
|
for functionality and memory leaks.
|
|
|
|
Adapted from a script by M-A Lemburg.
|
|
|
|
"""
|
|
from time import time
|
|
import unittest
|
|
from configdb import connection_factory
|
|
|
|
|
|
class DatabaseTest(unittest.TestCase):
    """Driver-independent DB-API capability tests.

    This class is a template: a concrete suite subclasses it and sets
    ``db_module`` (the DB-API module under test) together with any of the
    other class attributes below.  While ``db_module`` is ``None`` every
    test is skipped, so collecting this base class directly is harmless.
    """

    # DB-API module under test.  setUp() uses its Binary() constructor and
    # the date/time tests use its DateFromTicks()/TimeFromTicks()/
    # TimestampFromTicks() constructors.
    db_module = None
    # Not read anywhere in this class; kept for subclasses that set it.
    connect_args = ()
    # Keyword arguments forwarded to connection_factory() in setUp().
    connect_kwargs = {}
    # Text appended after the column list of every CREATE TABLE statement.
    create_table_extra = ""
    # Number of rows each data test inserts.
    rows = 10
    # When true, check_data_integrity() leaves its table in place.
    debug = False
    # When true, tearDown() asserts that gc.collect() finds nothing to
    # collect after the cursor and the connection are deleted.
    leak_test = True

    def setUp(self):
        """Open a connection and cursor and build the text/binary fixtures.

        The test is skipped when ``db_module`` has not been set by a
        subclass.
        """
        if self.db_module is None:
            # Checked before connecting so that an unconfigured base class
            # never opens a connection it would not close.
            self.skipTest("db_module is not set; subclass DatabaseTest to run")

        db = connection_factory(**self.connect_kwargs)
        self.connection = db
        self.cursor = db.cursor()
        # One character for each code point 0..16383, in order.
        self.BLOBUText = "".join(map(chr, range(16384)))
        # Byte values 0..255, repeated 16 times (4096 bytes).
        self.BLOBBinary = self.db_module.Binary(bytes(range(256)) * 16)

    def tearDown(self):
        """Delete the cursor and connection, checking for garbage if enabled.

        Nothing is done when ``leak_test`` is false.
        """
        if self.leak_test:
            import gc

            del self.cursor
            orphans = gc.collect()
            # assertFalse, not failIf: the deprecated alias no longer exists
            # on unittest.TestCase in Python 3.12+.
            self.assertFalse(
                orphans, "%d orphaned objects found after deleting cursor" % orphans
            )

            del self.connection
            orphans = gc.collect()
            self.assertFalse(
                orphans, "%d orphaned objects found after deleting connection" % orphans
            )

    def table_exists(self, name):
        """Return True if a SELECT against ``name`` executes without error.

        ``name`` is interpolated into the statement verbatim, so it must
        already be quoted.
        """
        try:
            self.cursor.execute("select * from %s where 1=0" % name)
        except Exception:
            # Any failure of the probe query is treated as "no such table".
            return False
        else:
            return True

    def quote_identifier(self, ident):
        """Return ``ident`` wrapped in double quotes."""
        return '"%s"' % ident

    def new_table_name(self):
        """Return a quoted table name for which table_exists() is false.

        Candidates are ``tb`` followed by a zero-padded hexadecimal number,
        starting at ``id(self.cursor)`` and counting upwards.
        """
        i = id(self.cursor)
        while True:
            name = self.quote_identifier("tb%08x" % i)
            if not self.table_exists(name):
                return name
            i = i + 1

    def create_table(self, columndefs):
        """Create a fresh table from a sequence of column definitions.

        ``columndefs`` is a sequence of SQL column definition strings such
        as ``"col1 INT"``.  The generated table name is stored in
        ``self.table`` and ``create_table_extra`` is appended to the
        statement.
        """
        self.table = self.new_table_name()
        self.cursor.execute(
            "CREATE TABLE %s (%s) %s"
            % (self.table, ",\n".join(columndefs), self.create_table_extra)
        )

    def _insert_statement(self, columndefs):
        """Return an INSERT for ``self.table`` with one %s per column."""
        return "INSERT INTO %s VALUES (%s)" % (
            self.table,
            ",".join(["%s"] * len(columndefs)),
        )

    def _make_rows(self, columndefs, generator):
        """Return ``self.rows`` rows of ``generator(row, col)`` values."""
        return [
            [generator(i, j) for j in range(len(columndefs))]
            for i in range(self.rows)
        ]

    def check_data_integrity(self, columndefs, generator):
        """Insert generated rows and assert they are read back unchanged.

        ``generator`` is called as ``generator(row_number, col_number)`` and
        must return the value to store in that cell.  It is called a second
        time for every cell to produce the expected value, so it must return
        equal values on repeated calls.  The table is dropped afterwards
        unless ``debug`` is set.
        """
        # insert
        self.create_table(columndefs)
        insert_statement = self._insert_statement(columndefs)
        data = self._make_rows(columndefs, generator)
        self.cursor.executemany(insert_statement, data)
        self.connection.commit()
        # verify
        self.cursor.execute("select * from %s" % self.table)
        res = self.cursor.fetchall()
        self.assertEqual(len(res), self.rows)
        try:
            for i in range(self.rows):
                for j in range(len(columndefs)):
                    self.assertEqual(res[i][j], generator(i, j))
        finally:
            if not self.debug:
                self.cursor.execute("drop table %s" % (self.table))

    def test_transactions(self):
        """Committed rows persist and an uncommitted DELETE is rolled back."""
        columndefs = ("col1 INT", "col2 VARCHAR(255)")

        def generator(row, col):
            if col == 0:
                return row
            else:
                return ("%i" % (row % 10)) * 255

        self.create_table(columndefs)
        insert_statement = self._insert_statement(columndefs)
        data = self._make_rows(columndefs, generator)
        self.cursor.executemany(insert_statement, data)
        # verify
        self.connection.commit()
        self.cursor.execute("select * from %s" % self.table)
        res = self.cursor.fetchall()
        self.assertEqual(len(res), self.rows)
        for i in range(self.rows):
            for j in range(len(columndefs)):
                self.assertEqual(res[i][j], generator(i, j))

        # Delete row 0 without committing; it must vanish, then reappear
        # after the rollback.
        delete_statement = "delete from %s where col1=%%s" % self.table
        self.cursor.execute(delete_statement, (0,))
        self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
        res = self.cursor.fetchall()
        self.assertFalse(res, "DELETE didn't work")
        self.connection.rollback()
        self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
        res = self.cursor.fetchall()
        self.assertEqual(len(res), 1, "ROLLBACK didn't work")
        self.cursor.execute("drop table %s" % (self.table))

    def test_truncation(self):
        """Over-long VARCHAR values must raise ``connection.DataError``.

        Checked for a single execute(), for a loop of execute() calls and
        for executemany().
        """
        columndefs = ("col1 INT", "col2 VARCHAR(255)")

        def generator(row, col):
            if col == 0:
                return row
            else:
                # String length grows by one per row, starting at
                # 255 - rows // 2.
                return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row)

        self.create_table(columndefs)
        insert_statement = self._insert_statement(columndefs)

        try:
            self.cursor.execute(insert_statement, (0, "0" * 256))
        except self.connection.DataError:
            pass
        else:
            self.fail(
                "Over-long column did not generate warnings/exception with single insert"  # noqa: E501
            )

        self.connection.rollback()

        try:
            for i in range(self.rows):
                row_values = tuple(generator(i, j) for j in range(len(columndefs)))
                self.cursor.execute(insert_statement, row_values)
        except self.connection.DataError:
            pass
        else:
            self.fail(
                "Over-long columns did not generate warnings/exception with execute()"  # noqa: E501
            )

        self.connection.rollback()

        try:
            data = self._make_rows(columndefs, generator)
            self.cursor.executemany(insert_statement, data)
        except self.connection.DataError:
            pass
        else:
            self.fail(
                "Over-long columns did not generate warnings/exception with executemany()"  # noqa: E501
            )

        self.connection.rollback()
        self.cursor.execute("drop table %s" % (self.table))

    def test_CHAR(self):
        """Round-trip 255-character strings through two CHAR(255) columns."""

        # Character data
        def generator(row, col):
            return ("%i" % ((row + col) % 10)) * 255

        self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator)

    def test_INT(self):
        """Round-trip squared row numbers through an INT column."""

        # Number data
        def generator(row, col):
            return row * row

        self.check_data_integrity(("col1 INT",), generator)

    def test_DECIMAL(self):
        """Round-trip Decimal values via a column and via bare SELECTs."""
        # DECIMAL
        from decimal import Decimal

        def generator(row, col):
            return Decimal("%d.%02d" % (row, col))

        self.check_data_integrity(("col1 DECIMAL(5,2)",), generator)

        val = Decimal("1.11111111111111119E-7")
        self.cursor.execute("SELECT %s", (val,))
        result = self.cursor.fetchone()[0]
        self.assertEqual(result, val)
        self.assertIsInstance(result, Decimal)

        self.cursor.execute("SELECT %s + %s", (Decimal("0.1"), Decimal("0.2")))
        result = self.cursor.fetchone()[0]
        self.assertEqual(result, Decimal("0.3"))
        self.assertIsInstance(result, Decimal)

    def test_DATE(self):
        """Round-trip ``DateFromTicks`` values through a DATE column."""
        ticks = time()

        def generator(row, col):
            return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313)

        self.check_data_integrity(("col1 DATE",), generator)

    def test_TIME(self):
        """Round-trip ``TimeFromTicks`` values through a TIME column."""
        ticks = time()

        def generator(row, col):
            return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313)

        self.check_data_integrity(("col1 TIME",), generator)

    def test_DATETIME(self):
        """Round-trip ``TimestampFromTicks`` values through a DATETIME column."""
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)

        self.check_data_integrity(("col1 DATETIME",), generator)

    def test_TIMESTAMP(self):
        """Round-trip ``TimestampFromTicks`` values through a TIMESTAMP column."""
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)

        self.check_data_integrity(("col1 TIMESTAMP",), generator)

    def test_fractional_TIMESTAMP(self):
        """Round-trip timestamps built with an extra ``row * 0.7 * col / 3.0`` term.

        Only one column is used here, so ``col`` is always 0 and that extra
        term evaluates to 0.
        """
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(
                ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0
            )

        self.check_data_integrity(("col1 TIMESTAMP",), generator)

    def test_LONG(self):
        """Round-trip the large text fixture through a LONG column."""

        def generator(row, col):
            if col == 0:
                return row
            else:
                return self.BLOBUText  # 'BLOB Text ' * 1024

        self.check_data_integrity(("col1 INT", "col2 LONG"), generator)

    def test_TEXT(self):
        """Round-trip the large text fixture through a TEXT column."""

        def generator(row, col):
            return self.BLOBUText  # 'BLOB Text ' * 1024

        self.check_data_integrity(("col2 TEXT",), generator)

    def test_LONG_BYTE(self):
        """Round-trip the binary fixture through a LONG BYTE column."""

        def generator(row, col):
            if col == 0:
                return row
            else:
                return self.BLOBBinary  # 'BLOB\000Binary ' * 1024

        self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator)

    def test_BLOB(self):
        """Round-trip the binary fixture through a BLOB column."""

        def generator(row, col):
            if col == 0:
                return row
            else:
                return self.BLOBBinary  # 'BLOB\000Binary ' * 1024

        self.check_data_integrity(("col1 INT", "col2 BLOB"), generator)

    def test_DOUBLE(self):
        """SELECTing a float parameter returns an equal ``float``."""
        for val in (18014398509481982.0, 0.1):
            self.cursor.execute("SELECT %s", (val,))
            result = self.cursor.fetchone()[0]
            self.assertEqual(result, val)
            self.assertIsInstance(result, float)