mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 02:54:29 +08:00
Use _binary prefix for bytes/bytearray parameters (#140)
- Based on #106 but now disabled by default - Can be enabled via 'binary_prefix' connection parameter - Added unit tests to verify behaviour
This commit is contained in:

committed by
INADA Naoki

parent
50a81b1783
commit
cba486e043
@ -27,6 +27,7 @@ apilevel = "2.0"
|
|||||||
paramstyle = "format"
|
paramstyle = "format"
|
||||||
|
|
||||||
from _mysql import *
|
from _mysql import *
|
||||||
|
from MySQLdb.compat import PY2
|
||||||
from MySQLdb.constants import FIELD_TYPE
|
from MySQLdb.constants import FIELD_TYPE
|
||||||
from MySQLdb.times import Date, Time, Timestamp, \
|
from MySQLdb.times import Date, Time, Timestamp, \
|
||||||
DateFromTicks, TimeFromTicks, TimestampFromTicks
|
DateFromTicks, TimeFromTicks, TimestampFromTicks
|
||||||
@ -72,6 +73,10 @@ def test_DBAPISet_set_equality_membership():
|
|||||||
def test_DBAPISet_set_inequality_membership():
|
def test_DBAPISet_set_inequality_membership():
|
||||||
assert FIELD_TYPE.DATE != STRING
|
assert FIELD_TYPE.DATE != STRING
|
||||||
|
|
||||||
|
if PY2:
|
||||||
|
def Binary(x):
|
||||||
|
return bytearray(x)
|
||||||
|
else:
|
||||||
def Binary(x):
|
def Binary(x):
|
||||||
return bytes(x)
|
return bytes(x)
|
||||||
|
|
||||||
|
@ -137,6 +137,10 @@ class Connection(_mysql.connection):
|
|||||||
If True, autocommit is enabled.
|
If True, autocommit is enabled.
|
||||||
If None, autocommit isn't set and server default is used.
|
If None, autocommit isn't set and server default is used.
|
||||||
|
|
||||||
|
:param bool binary_prefix:
|
||||||
|
If set, the '_binary' prefix will be used for raw byte query
|
||||||
|
arguments (e.g. Binary). This is disabled by default.
|
||||||
|
|
||||||
There are a number of undocumented, non-standard methods. See the
|
There are a number of undocumented, non-standard methods. See the
|
||||||
documentation for the MySQL C API for some hints on what they do.
|
documentation for the MySQL C API for some hints on what they do.
|
||||||
"""
|
"""
|
||||||
@ -174,6 +178,7 @@ class Connection(_mysql.connection):
|
|||||||
|
|
||||||
use_unicode = kwargs2.pop('use_unicode', use_unicode)
|
use_unicode = kwargs2.pop('use_unicode', use_unicode)
|
||||||
sql_mode = kwargs2.pop('sql_mode', '')
|
sql_mode = kwargs2.pop('sql_mode', '')
|
||||||
|
binary_prefix = kwargs2.pop('binary_prefix', False)
|
||||||
|
|
||||||
client_flag = kwargs.get('client_flag', 0)
|
client_flag = kwargs.get('client_flag', 0)
|
||||||
client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ])
|
client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ])
|
||||||
@ -197,7 +202,7 @@ class Connection(_mysql.connection):
|
|||||||
|
|
||||||
db = proxy(self)
|
db = proxy(self)
|
||||||
def _get_string_literal():
|
def _get_string_literal():
|
||||||
# Note: string_literal() is called for bytes object on Python 3.
|
# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
|
||||||
def string_literal(obj, dummy=None):
|
def string_literal(obj, dummy=None):
|
||||||
return db.string_literal(obj)
|
return db.string_literal(obj)
|
||||||
return string_literal
|
return string_literal
|
||||||
@ -206,13 +211,18 @@ class Connection(_mysql.connection):
|
|||||||
if PY2:
|
if PY2:
|
||||||
# unicode_literal is called for only unicode object.
|
# unicode_literal is called for only unicode object.
|
||||||
def unicode_literal(u, dummy=None):
|
def unicode_literal(u, dummy=None):
|
||||||
return db.literal(u.encode(unicode_literal.charset))
|
return db.string_literal(u.encode(unicode_literal.charset))
|
||||||
else:
|
else:
|
||||||
# unicode_literal() is called for arbitrary object.
|
# unicode_literal() is called for arbitrary object.
|
||||||
def unicode_literal(u, dummy=None):
|
def unicode_literal(u, dummy=None):
|
||||||
return db.literal(str(u).encode(unicode_literal.charset))
|
return db.string_literal(str(u).encode(unicode_literal.charset))
|
||||||
return unicode_literal
|
return unicode_literal
|
||||||
|
|
||||||
|
def _get_bytes_literal():
|
||||||
|
def bytes_literal(obj, dummy=None):
|
||||||
|
return b'_binary' + db.string_literal(obj)
|
||||||
|
return bytes_literal
|
||||||
|
|
||||||
def _get_string_decoder():
|
def _get_string_decoder():
|
||||||
def string_decoder(s):
|
def string_decoder(s):
|
||||||
return s.decode(string_decoder.charset)
|
return s.decode(string_decoder.charset)
|
||||||
@ -220,6 +230,7 @@ class Connection(_mysql.connection):
|
|||||||
|
|
||||||
string_literal = _get_string_literal()
|
string_literal = _get_string_literal()
|
||||||
self.unicode_literal = unicode_literal = _get_unicode_literal()
|
self.unicode_literal = unicode_literal = _get_unicode_literal()
|
||||||
|
bytes_literal = _get_bytes_literal()
|
||||||
self.string_decoder = string_decoder = _get_string_decoder()
|
self.string_decoder = string_decoder = _get_string_decoder()
|
||||||
if not charset:
|
if not charset:
|
||||||
charset = self.character_set_name()
|
charset = self.character_set_name()
|
||||||
@ -234,7 +245,12 @@ class Connection(_mysql.connection):
|
|||||||
self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder))
|
self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder))
|
||||||
self.converter[FIELD_TYPE.BLOB].append((None, string_decoder))
|
self.converter[FIELD_TYPE.BLOB].append((None, string_decoder))
|
||||||
|
|
||||||
|
if binary_prefix:
|
||||||
|
self.encoders[bytes] = string_literal if PY2 else bytes_literal
|
||||||
|
self.encoders[bytearray] = bytes_literal
|
||||||
|
else:
|
||||||
self.encoders[bytes] = string_literal
|
self.encoders[bytes] = string_literal
|
||||||
|
|
||||||
self.encoders[unicode] = unicode_literal
|
self.encoders[unicode] = unicode_literal
|
||||||
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
|
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
|
||||||
if self._transactional:
|
if self._transactional:
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import capabilities
|
import capabilities
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from contextlib import closing
|
||||||
import unittest
|
import unittest
|
||||||
import MySQLdb
|
import MySQLdb
|
||||||
from MySQLdb.compat import unicode
|
from MySQLdb.compat import unicode
|
||||||
from MySQLdb import cursors
|
from MySQLdb import cursors
|
||||||
|
from configdb import connection_factory
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
@ -180,6 +182,23 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
|||||||
finally:
|
finally:
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
|
def test_binary_prefix(self):
|
||||||
|
# verify prefix behaviour when enabled, disabled and for default (disabled)
|
||||||
|
for binary_prefix in (True, False, None):
|
||||||
|
kwargs = self.connect_kwargs.copy()
|
||||||
|
# needs to be set to can guarantee CHARSET response for normal strings
|
||||||
|
kwargs['charset'] = 'utf8'
|
||||||
|
if binary_prefix != None:
|
||||||
|
kwargs['binary_prefix'] = binary_prefix
|
||||||
|
|
||||||
|
with closing(connection_factory(**kwargs)) as conn:
|
||||||
|
with closing(conn.cursor()) as c:
|
||||||
|
c.execute('SELECT CHARSET(%s)', (MySQLdb.Binary(b'raw bytes'),))
|
||||||
|
self.assertEqual(c.fetchall()[0][0], 'binary' if binary_prefix else 'utf8')
|
||||||
|
# normal strings should not get prefix
|
||||||
|
c.execute('SELECT CHARSET(%s)', ('str',))
|
||||||
|
self.assertEqual(c.fetchall()[0][0], 'utf8')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if test_MySQLdb.leak_test:
|
if test_MySQLdb.leak_test:
|
||||||
|
Reference in New Issue
Block a user