Add collation option (#564)

Fixes #563
This commit is contained in:
Vince Salvino
2023-05-08 22:45:28 -04:00
committed by GitHub
parent aed1dd2632
commit df52e237b3
3 changed files with 60 additions and 2 deletions

View File

@ -97,6 +97,14 @@ class Connection(_mysql.connection):
If supplied, the connection character set will be changed
to this character set.
:param str collation:
If ``charset`` and ``collation`` are both supplied, the
character set and collation for the current connection
will be set.
If omitted, empty string, or None, the default collation
for the ``charset`` is implied.
:param str auth_plugin:
If supplied, the connection default authentication plugin will be
changed to this value. Example values:
@ -167,6 +175,7 @@ class Connection(_mysql.connection):
cursorclass = kwargs2.pop("cursorclass", self.default_cursor) cursorclass = kwargs2.pop("cursorclass", self.default_cursor)
charset = kwargs2.get("charset", "") charset = kwargs2.get("charset", "")
collation = kwargs2.pop("collation", "")
use_unicode = kwargs2.pop("use_unicode", True) use_unicode = kwargs2.pop("use_unicode", True)
sql_mode = kwargs2.pop("sql_mode", "") sql_mode = kwargs2.pop("sql_mode", "")
self._binary_prefix = kwargs2.pop("binary_prefix", False) self._binary_prefix = kwargs2.pop("binary_prefix", False)
@ -193,7 +202,7 @@ class Connection(_mysql.connection):
if not charset: if not charset:
charset = self.character_set_name() charset = self.character_set_name()
self.set_character_set(charset) self.set_character_set(charset, collation)
if sql_mode: if sql_mode:
self.set_sql_mode(sql_mode) self.set_sql_mode(sql_mode)
@ -285,10 +294,13 @@ class Connection(_mysql.connection):
""" """
self.query(b"BEGIN") self.query(b"BEGIN")
def set_character_set(self, charset, collation=None):
    """Set the connection character set, and optionally its collation.

    :param str charset:
        Character set name. It is passed to the underlying connection's
        ``set_character_set`` and also used to pick ``self.encoding``.
    :param str collation:
        Collation name. When truthy, a ``SET NAMES ... COLLATE ...``
        statement is issued for the connection. When omitted, ``None``
        or an empty string, no such statement is sent.
    """
    super().set_character_set(charset)
    # Fall back to the charset name itself when the mapping has no entry.
    self.encoding = _charset_to_encoding.get(charset, charset)
    if collation:
        # NOTE(review): both names are interpolated into the SQL text
        # verbatim, so they must come from trusted configuration, never
        # from end-user input.
        self.query("SET NAMES %s COLLATE %s" % (charset, collation))
        # Collect the statement's result immediately after the query;
        # presumably this leaves the connection ready for the next one.
        self.store_result()
def set_sql_mode(self, sql_mode): def set_sql_mode(self, sql_mode):
"""Set the connection sql_mode. See MySQL documentation for """Set the connection sql_mode. See MySQL documentation for

View File

@ -348,6 +348,22 @@ connect(parameters...)
*This must be a keyword parameter.*
collation
If ``charset`` and ``collation`` are both supplied, the
character set and collation for the current connection
will be set.
If omitted, empty string, or None, the default collation
for the ``charset`` is implied by the database server.
To learn more about the quiddities of character sets and
collations, consult the `MySQL docs
<https://dev.mysql.com/doc/refman/8.0/en/charset.html>`_
and the `MariaDB docs
<https://mariadb.com/kb/en/character-sets/>`_.
*This must be a keyword parameter.*
sql_mode
If present, the session SQL mode will be set to the given
string. For more information on sql_mode, see the MySQL

View File

@ -114,3 +114,33 @@ class CoreAPI(unittest.TestCase):
with connection_factory() as conn: with connection_factory() as conn:
self.assertFalse(conn.closed) self.assertFalse(conn.closed)
self.assertTrue(conn.closed) self.assertTrue(conn.closed)
class TestCollation(unittest.TestCase):
    """Test charset and collation connection options."""

    def setUp(self):
        # Initialize a connection with a non-default character set and
        # collation.
        self.conn = connection_factory(
            charset="utf8mb4",
            collation="utf8mb4_esperanto_ci",
        )

    def tearDown(self):
        self.conn.close()

    def test_charset_collation(self):
        """The session reports the charset and collation that were requested."""
        cursor = self.conn.cursor()
        self.addCleanup(cursor.close)
        # Single-quoted literals stay string literals even when the server
        # runs with ANSI_QUOTES, where double quotes delimit identifiers.
        cursor.execute(
            """
            SHOW VARIABLES WHERE
            Variable_Name='character_set_connection' OR
            Variable_Name='collation_connection';
            """
        )
        # Look the values up by variable name instead of by row position so
        # the test does not depend on the order the server returns rows in.
        variables = dict(cursor.fetchall())
        self.assertEqual(variables["character_set_connection"], "utf8mb4")
        self.assertEqual(
            variables["collation_connection"], "utf8mb4_esperanto_ci"
        )