# Repository web-view metadata for this file:
#   Path:    mysqlclient/tests/capabilities.py
#   Commit:  d56b0b7f8c "black 20.8b1" (author: Inada Naoki)
#   Date:    2020-12-04 12:22:51 +09:00
#   Size:    318 lines, 9.8 KiB
#   Language: Python

#!/usr/bin/env python -O
""" Script to test database capabilities and the DB-API interface
for functionality and memory leaks.
Adapted from a script by M-A Lemburg.
"""
from time import time
import unittest
from configdb import connection_factory
class DatabaseTest(unittest.TestCase):
db_module = None
connect_args = ()
connect_kwargs = dict()
create_table_extra = ""
rows = 10
debug = False
def setUp(self):
db = connection_factory(**self.connect_kwargs)
self.connection = db
self.cursor = db.cursor()
self.BLOBUText = "".join([chr(i) for i in range(16384)])
self.BLOBBinary = self.db_module.Binary(
("".join([chr(i) for i in range(256)] * 16)).encode("latin1")
)
leak_test = True
def tearDown(self):
if self.leak_test:
import gc
del self.cursor
orphans = gc.collect()
self.failIf(
orphans, "%d orphaned objects found after deleting cursor" % orphans
)
del self.connection
orphans = gc.collect()
self.failIf(
orphans, "%d orphaned objects found after deleting connection" % orphans
)
def table_exists(self, name):
try:
self.cursor.execute("select * from %s where 1=0" % name)
except Exception:
return False
else:
return True
def quote_identifier(self, ident):
return '"%s"' % ident
def new_table_name(self):
i = id(self.cursor)
while True:
name = self.quote_identifier("tb%08x" % i)
if not self.table_exists(name):
return name
i = i + 1
def create_table(self, columndefs):
"""Create a table using a list of column definitions given in
columndefs.
generator must be a function taking arguments (row_number,
col_number) returning a suitable data object for insertion
into the table.
"""
self.table = self.new_table_name()
self.cursor.execute(
"CREATE TABLE %s (%s) %s"
% (self.table, ",\n".join(columndefs), self.create_table_extra)
)
def check_data_integrity(self, columndefs, generator):
# insert
self.create_table(columndefs)
insert_statement = "INSERT INTO %s VALUES (%s)" % (
self.table,
",".join(["%s"] * len(columndefs)),
)
data = [
[generator(i, j) for j in range(len(columndefs))] for i in range(self.rows)
]
self.cursor.executemany(insert_statement, data)
self.connection.commit()
# verify
self.cursor.execute("select * from %s" % self.table)
res = self.cursor.fetchall()
self.assertEqual(len(res), self.rows)
try:
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEqual(res[i][j], generator(i, j))
finally:
if not self.debug:
self.cursor.execute("drop table %s" % (self.table))
def test_transactions(self):
columndefs = ("col1 INT", "col2 VARCHAR(255)")
def generator(row, col):
if col == 0:
return row
else:
return ("%i" % (row % 10)) * 255
self.create_table(columndefs)
insert_statement = "INSERT INTO %s VALUES (%s)" % (
self.table,
",".join(["%s"] * len(columndefs)),
)
data = [
[generator(i, j) for j in range(len(columndefs))] for i in range(self.rows)
]
self.cursor.executemany(insert_statement, data)
# verify
self.connection.commit()
self.cursor.execute("select * from %s" % self.table)
res = self.cursor.fetchall()
self.assertEqual(len(res), self.rows)
for i in range(self.rows):
for j in range(len(columndefs)):
self.assertEqual(res[i][j], generator(i, j))
delete_statement = "delete from %s where col1=%%s" % self.table
self.cursor.execute(delete_statement, (0,))
self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
res = self.cursor.fetchall()
self.assertFalse(res, "DELETE didn't work")
self.connection.rollback()
self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
res = self.cursor.fetchall()
self.assertTrue(len(res) == 1, "ROLLBACK didn't work")
self.cursor.execute("drop table %s" % (self.table))
def test_truncation(self):
columndefs = ("col1 INT", "col2 VARCHAR(255)")
def generator(row, col):
if col == 0:
return row
else:
return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row)
self.create_table(columndefs)
insert_statement = "INSERT INTO %s VALUES (%s)" % (
self.table,
",".join(["%s"] * len(columndefs)),
)
try:
self.cursor.execute(insert_statement, (0, "0" * 256))
except self.connection.DataError:
pass
else:
self.fail(
"Over-long column did not generate warnings/exception with single insert" # noqa: E501
)
self.connection.rollback()
try:
for i in range(self.rows):
data = []
for j in range(len(columndefs)):
data.append(generator(i, j))
self.cursor.execute(insert_statement, tuple(data))
except self.connection.DataError:
pass
else:
self.fail(
"Over-long columns did not generate warnings/exception with execute()" # noqa: E501
)
self.connection.rollback()
try:
data = [
[generator(i, j) for j in range(len(columndefs))]
for i in range(self.rows)
]
self.cursor.executemany(insert_statement, data)
except self.connection.DataError:
pass
else:
self.fail(
"Over-long columns did not generate warnings/exception with executemany()" # noqa: E501
)
self.connection.rollback()
self.cursor.execute("drop table %s" % (self.table))
def test_CHAR(self):
# Character data
def generator(row, col):
return ("%i" % ((row + col) % 10)) * 255
self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator)
def test_INT(self):
# Number data
def generator(row, col):
return row * row
self.check_data_integrity(("col1 INT",), generator)
def test_DECIMAL(self):
# DECIMAL
from decimal import Decimal
def generator(row, col):
return Decimal("%d.%02d" % (row, col))
self.check_data_integrity(("col1 DECIMAL(5,2)",), generator)
val = Decimal("1.11111111111111119E-7")
self.cursor.execute("SELECT %s", (val,))
result = self.cursor.fetchone()[0]
self.assertEqual(result, val)
self.assertIsInstance(result, Decimal)
self.cursor.execute("SELECT %s + %s", (Decimal("0.1"), Decimal("0.2")))
result = self.cursor.fetchone()[0]
self.assertEqual(result, Decimal("0.3"))
self.assertIsInstance(result, Decimal)
def test_DATE(self):
ticks = time()
def generator(row, col):
return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313)
self.check_data_integrity(("col1 DATE",), generator)
def test_TIME(self):
ticks = time()
def generator(row, col):
return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313)
self.check_data_integrity(("col1 TIME",), generator)
def test_DATETIME(self):
ticks = time()
def generator(row, col):
return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
self.check_data_integrity(("col1 DATETIME",), generator)
def test_TIMESTAMP(self):
ticks = time()
def generator(row, col):
return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
self.check_data_integrity(("col1 TIMESTAMP",), generator)
def test_fractional_TIMESTAMP(self):
ticks = time()
def generator(row, col):
return self.db_module.TimestampFromTicks(
ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0
)
self.check_data_integrity(("col1 TIMESTAMP",), generator)
def test_LONG(self):
def generator(row, col):
if col == 0:
return row
else:
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(("col1 INT", "col2 LONG"), generator)
def test_TEXT(self):
def generator(row, col):
return self.BLOBUText # 'BLOB Text ' * 1024
self.check_data_integrity(("col2 TEXT",), generator)
def test_LONG_BYTE(self):
def generator(row, col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator)
def test_BLOB(self):
def generator(row, col):
if col == 0:
return row
else:
return self.BLOBBinary # 'BLOB\000Binary ' * 1024
self.check_data_integrity(("col1 INT", "col2 BLOB"), generator)
def test_DOUBLE(self):
for val in (18014398509481982.0, 0.1):
self.cursor.execute("SELECT %s", (val,))
result = self.cursor.fetchone()[0]
self.assertEqual(result, val)
self.assertIsInstance(result, float)