Use _binary prefix for bytes/bytearray parameters (#140)

- Based on #106 but now disabled by default
- Can be enabled via 'binary_prefix' connection parameter
- Added unit tests to verify behaviour
This commit is contained in:
Vilnis Termanis
2017-02-10 11:36:41 +00:00
committed by INADA Naoki
parent 50a81b1783
commit cba486e043
3 changed files with 46 additions and 6 deletions

View File

@ -27,6 +27,7 @@ apilevel = "2.0"
paramstyle = "format" paramstyle = "format"
from _mysql import * from _mysql import *
from MySQLdb.compat import PY2
from MySQLdb.constants import FIELD_TYPE from MySQLdb.constants import FIELD_TYPE
from MySQLdb.times import Date, Time, Timestamp, \ from MySQLdb.times import Date, Time, Timestamp, \
DateFromTicks, TimeFromTicks, TimestampFromTicks DateFromTicks, TimeFromTicks, TimestampFromTicks
@ -72,6 +73,10 @@ def test_DBAPISet_set_equality_membership():
def test_DBAPISet_set_inequality_membership(): def test_DBAPISet_set_inequality_membership():
assert FIELD_TYPE.DATE != STRING assert FIELD_TYPE.DATE != STRING
if PY2:
def Binary(x):
return bytearray(x)
else:
def Binary(x): def Binary(x):
return bytes(x) return bytes(x)

View File

@ -137,6 +137,10 @@ class Connection(_mysql.connection):
If True, autocommit is enabled. If True, autocommit is enabled.
If None, autocommit isn't set and server default is used. If None, autocommit isn't set and server default is used.
:param bool binary_prefix:
If set, the '_binary' prefix will be used for raw byte query
arguments (e.g. Binary). This is disabled by default.
There are a number of undocumented, non-standard methods. See the There are a number of undocumented, non-standard methods. See the
documentation for the MySQL C API for some hints on what they do. documentation for the MySQL C API for some hints on what they do.
""" """
@ -174,6 +178,7 @@ class Connection(_mysql.connection):
use_unicode = kwargs2.pop('use_unicode', use_unicode) use_unicode = kwargs2.pop('use_unicode', use_unicode)
sql_mode = kwargs2.pop('sql_mode', '') sql_mode = kwargs2.pop('sql_mode', '')
binary_prefix = kwargs2.pop('binary_prefix', False)
client_flag = kwargs.get('client_flag', 0) client_flag = kwargs.get('client_flag', 0)
client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ]) client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ])
@ -197,7 +202,7 @@ class Connection(_mysql.connection):
db = proxy(self) db = proxy(self)
def _get_string_literal(): def _get_string_literal():
# Note: string_literal() is called for bytes object on Python 3. # Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
def string_literal(obj, dummy=None): def string_literal(obj, dummy=None):
return db.string_literal(obj) return db.string_literal(obj)
return string_literal return string_literal
@ -206,13 +211,18 @@ class Connection(_mysql.connection):
if PY2: if PY2:
# unicode_literal is called for only unicode object. # unicode_literal is called for only unicode object.
def unicode_literal(u, dummy=None): def unicode_literal(u, dummy=None):
return db.literal(u.encode(unicode_literal.charset)) return db.string_literal(u.encode(unicode_literal.charset))
else: else:
# unicode_literal() is called for arbitrary object. # unicode_literal() is called for arbitrary object.
def unicode_literal(u, dummy=None): def unicode_literal(u, dummy=None):
return db.literal(str(u).encode(unicode_literal.charset)) return db.string_literal(str(u).encode(unicode_literal.charset))
return unicode_literal return unicode_literal
def _get_bytes_literal():
def bytes_literal(obj, dummy=None):
return b'_binary' + db.string_literal(obj)
return bytes_literal
def _get_string_decoder(): def _get_string_decoder():
def string_decoder(s): def string_decoder(s):
return s.decode(string_decoder.charset) return s.decode(string_decoder.charset)
@ -220,6 +230,7 @@ class Connection(_mysql.connection):
string_literal = _get_string_literal() string_literal = _get_string_literal()
self.unicode_literal = unicode_literal = _get_unicode_literal() self.unicode_literal = unicode_literal = _get_unicode_literal()
bytes_literal = _get_bytes_literal()
self.string_decoder = string_decoder = _get_string_decoder() self.string_decoder = string_decoder = _get_string_decoder()
if not charset: if not charset:
charset = self.character_set_name() charset = self.character_set_name()
@ -234,7 +245,12 @@ class Connection(_mysql.connection):
self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder)) self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder))
self.converter[FIELD_TYPE.BLOB].append((None, string_decoder)) self.converter[FIELD_TYPE.BLOB].append((None, string_decoder))
if binary_prefix:
self.encoders[bytes] = string_literal if PY2 else bytes_literal
self.encoders[bytearray] = bytes_literal
else:
self.encoders[bytes] = string_literal self.encoders[bytes] = string_literal
self.encoders[unicode] = unicode_literal self.encoders[unicode] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional: if self._transactional:

View File

@ -2,10 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import capabilities import capabilities
from datetime import timedelta from datetime import timedelta
from contextlib import closing
import unittest import unittest
import MySQLdb import MySQLdb
from MySQLdb.compat import unicode from MySQLdb.compat import unicode
from MySQLdb import cursors from MySQLdb import cursors
from configdb import connection_factory
import warnings import warnings
@ -180,6 +182,23 @@ class test_MySQLdb(capabilities.DatabaseTest):
finally: finally:
c.close() c.close()
def test_binary_prefix(self):
# verify prefix behaviour when enabled, disabled and for default (disabled)
for binary_prefix in (True, False, None):
kwargs = self.connect_kwargs.copy()
# needs to be set to can guarantee CHARSET response for normal strings
kwargs['charset'] = 'utf8'
if binary_prefix != None:
kwargs['binary_prefix'] = binary_prefix
with closing(connection_factory(**kwargs)) as conn:
with closing(conn.cursor()) as c:
c.execute('SELECT CHARSET(%s)', (MySQLdb.Binary(b'raw bytes'),))
self.assertEqual(c.fetchall()[0][0], 'binary' if binary_prefix else 'utf8')
# normal strings should not get prefix
c.execute('SELECT CHARSET(%s)', ('str',))
self.assertEqual(c.fetchall()[0][0], 'utf8')
if __name__ == '__main__': if __name__ == '__main__':
if test_MySQLdb.leak_test: if test_MySQLdb.leak_test: