Fix executemany with binary prefix (#605)

Fix #494
This commit is contained in:
Inada Naoki
2023-05-18 17:52:38 +09:00
committed by GitHub
parent 3d6b8c9b7c
commit 62f0645376
3 changed files with 27 additions and 35 deletions

View File

@ -110,34 +110,6 @@ class BaseCursor:
del exc_info del exc_info
self.close() self.close()
def _escape_args(self, args, conn):
encoding = conn.encoding
literal = conn.literal
def ensure_bytes(x):
if isinstance(x, str):
return x.encode(encoding)
elif isinstance(x, tuple):
return tuple(map(ensure_bytes, x))
elif isinstance(x, list):
return list(map(ensure_bytes, x))
return x
if isinstance(args, (tuple, list)):
ret = tuple(literal(ensure_bytes(arg)) for arg in args)
elif isinstance(args, dict):
ret = {
ensure_bytes(key): literal(ensure_bytes(val))
for (key, val) in args.items()
}
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
ret = literal(ensure_bytes(args))
ensure_bytes = None # break circular reference
return ret
def _check_executed(self): def _check_executed(self):
if not self._executed: if not self._executed:
raise ProgrammingError("execute() first") raise ProgrammingError("execute() first")
@ -279,8 +251,6 @@ class BaseCursor:
def _do_execute_many( def _do_execute_many(
self, prefix, values, postfix, args, max_stmt_length, encoding self, prefix, values, postfix, args, max_stmt_length, encoding
): ):
conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, str): if isinstance(prefix, str):
prefix = prefix.encode(encoding) prefix = prefix.encode(encoding)
if isinstance(values, str): if isinstance(values, str):
@ -289,11 +259,11 @@ class BaseCursor:
postfix = postfix.encode(encoding) postfix = postfix.encode(encoding)
sql = bytearray(prefix) sql = bytearray(prefix)
args = iter(args) args = iter(args)
v = values % escape(next(args), conn) v = self._mogrify(values, next(args))
sql += v sql += v
rows = 0 rows = 0
for arg in args: for arg in args:
v = values % escape(arg, conn) v = self._mogrify(values, arg)
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql + postfix) rows += self.execute(sql + postfix)
sql = bytearray(prefix) sql = bytearray(prefix)

View File

@ -2,9 +2,10 @@
# http://dev.mysql.com/doc/refman/5.1/en/option-files.html # http://dev.mysql.com/doc/refman/5.1/en/option-files.html
# and set TESTDB in your environment to the name of the file # and set TESTDB in your environment to the name of the file
# $ docker run -e MYSQL_ALLOW_EMPTY_PASSWORD=yes -p 3306:3306 --rm --name mysqld mysql:latest
[MySQLdb-tests] [MySQLdb-tests]
host = 127.0.0.1 host = 127.0.0.1
user = test user = root
database = test database = test
#password = #password =
default-character-set = utf8 default-character-set = utf8mb4

View File

@ -72,7 +72,7 @@ def test_executemany():
# values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9) # values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)
# """ # """
# list args # list args
data = range(10) data = [(i,) for i in range(10)]
cursor.executemany("insert into test (data) values (%s)", data) cursor.executemany("insert into test (data) values (%s)", data)
assert cursor._executed.endswith( assert cursor._executed.endswith(
b",(7),(8),(9)" b",(7),(8),(9)"
@ -222,3 +222,24 @@ SELECT * FROM test_cursor_discard_result WHERE id BETWEEN 21 AND 30;
"SELECT * FROM test_cursor_discard_result WHERE id BETWEEN 31 AND 40" "SELECT * FROM test_cursor_discard_result WHERE id BETWEEN 31 AND 40"
) )
assert cursor.fetchone() == (31, "row 31") assert cursor.fetchone() == (31, "row 31")
def test_binary_prefix():
# https://github.com/PyMySQL/mysqlclient/issues/494
conn = connect(binary_prefix=True)
cursor = conn.cursor()
cursor.execute("DROP TABLE IF EXISTS test_binary_prefix")
cursor.execute(
"""\
CREATE TABLE test_binary_prefix (
id INTEGER NOT NULL AUTO_INCREMENT,
json JSON NOT NULL,
PRIMARY KEY (id)
) CHARSET=utf8mb4"""
)
cursor.executemany(
"INSERT INTO test_binary_prefix (id, json) VALUES (%(id)s, %(json)s)",
({"id": 1, "json": "{}"}, {"id": 2, "json": "{}"}),
)