mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 11:10:58 +08:00
multi statements can be disabled (#500)
This commit is contained in:
@ -110,6 +110,10 @@ class Connection(_mysql.connection):
|
|||||||
:param int client_flag:
|
:param int client_flag:
|
||||||
flags to use or 0 (see MySQL docs or constants/CLIENTS.py)
|
flags to use or 0 (see MySQL docs or constants/CLIENTS.py)
|
||||||
|
|
||||||
|
:param bool multi_statements:
|
||||||
|
If True, enable multi statements for clients >= 4.1.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
:param str ssl_mode:
|
:param str ssl_mode:
|
||||||
specify the security settings for connection to the server;
|
specify the security settings for connection to the server;
|
||||||
see the MySQL documentation for more details
|
see the MySQL documentation for more details
|
||||||
@ -169,11 +173,16 @@ class Connection(_mysql.connection):
|
|||||||
self._binary_prefix = kwargs2.pop("binary_prefix", False)
|
self._binary_prefix = kwargs2.pop("binary_prefix", False)
|
||||||
|
|
||||||
client_flag = kwargs.get("client_flag", 0)
|
client_flag = kwargs.get("client_flag", 0)
|
||||||
|
|
||||||
client_version = tuple(
|
client_version = tuple(
|
||||||
[numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]]
|
[numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multi_statements = kwargs2.pop("multi_statements", True)
|
||||||
|
if multi_statements:
|
||||||
if client_version >= (4, 1):
|
if client_version >= (4, 1):
|
||||||
client_flag |= CLIENT.MULTI_STATEMENTS
|
client_flag |= CLIENT.MULTI_STATEMENTS
|
||||||
|
|
||||||
if client_version >= (5, 0):
|
if client_version >= (5, 0):
|
||||||
client_flag |= CLIENT.MULTI_RESULTS
|
client_flag |= CLIENT.MULTI_RESULTS
|
||||||
|
|
||||||
|
26
tests/test_connection.py
Normal file
26
tests/test_connection.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from MySQLdb._exceptions import ProgrammingError
|
||||||
|
|
||||||
|
from configdb import connection_factory
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_statements_default_true():
|
||||||
|
conn = connection_factory()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("select 17; select 2")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
assert rows == ((17,),)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_statements_false():
|
||||||
|
conn = connection_factory(multi_statements=False)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
with pytest.raises(ProgrammingError):
|
||||||
|
cursor.execute("select 17; select 2")
|
||||||
|
|
||||||
|
cursor.execute("select 17")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
assert rows == ((17,),)
|
Reference in New Issue
Block a user