Fix Connection.escape() with Unicode input (#608)

After aed1dd2, Connection.escape() used ASCII to escape Unicode input.
This commit makes it uses connection encoding instead.
This commit is contained in:
Inada Naoki
2023-05-18 20:08:04 +09:00
committed by GitHub
parent 44d0f7a148
commit b162dddcf3

View File

@ -943,7 +943,7 @@ _mysql_escape_string(
{
PyObject *str;
char *in, *out;
int len;
unsigned long len;
Py_ssize_t size;
if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL;
str = PyBytes_FromStringAndSize((char *) NULL, size*2+1);
@ -980,10 +980,7 @@ _mysql_string_literal(
_mysql_ConnectionObject *self,
PyObject *o)
{
PyObject *str, *s;
char *in, *out;
unsigned long len;
Py_ssize_t size;
PyObject *s; // input string or bytes. need to decref.
if (self && PyModule_Check((PyObject*)self))
self = NULL;
@ -991,24 +988,44 @@ _mysql_string_literal(
if (PyBytes_Check(o)) {
s = o;
Py_INCREF(s);
} else {
s = PyObject_Str(o);
if (!s) return NULL;
{
PyObject *t = PyUnicode_AsASCIIString(s);
Py_DECREF(s);
if (!t) return NULL;
}
else {
PyObject *t = PyObject_Str(o);
if (!t) return NULL;
const char *encoding = (self && self->open) ?
_get_encoding(&self->connection) : utf8;
if (encoding == utf8) {
s = t;
}
else {
s = PyUnicode_AsEncodedString(t, encoding, "strict");
Py_DECREF(t);
if (!s) return NULL;
}
}
in = PyBytes_AsString(s);
size = PyBytes_GET_SIZE(s);
str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);
// Prepare input string (in, size)
const char *in;
Py_ssize_t size;
if (PyUnicode_Check(s)) {
in = PyUnicode_AsUTF8AndSize(s, &size);
} else {
assert(PyBytes_Check(s));
in = PyBytes_AsString(s);
size = PyBytes_GET_SIZE(s);
}
// Prepare output buffer (str, out)
PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);
if (!str) {
Py_DECREF(s);
return PyErr_NoMemory();
}
out = PyBytes_AS_STRING(str);
char *out = PyBytes_AS_STRING(str);
// escape
unsigned long len;
if (self && self->open) {
#if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID)
len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\'');
@ -1018,10 +1035,14 @@ _mysql_string_literal(
} else {
len = mysql_escape_string(out+1, in, size);
}
*out = *(out+len+1) = '\'';
if (_PyBytes_Resize(&str, len+2) < 0) return NULL;
Py_DECREF(s);
return (str);
*out = *(out+len+1) = '\'';
if (_PyBytes_Resize(&str, len+2) < 0) {
Py_DECREF(str);
return NULL;
}
return str;
}
static PyObject *
@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard(
// do nothing
}
Py_END_ALLOW_THREADS
if (mysql_errno(self->conn)) {
return _mysql_Exception(self->conn);
_mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn;
if (mysql_errno(&conn->connection)) {
return _mysql_Exception(conn);
}
Py_RETURN_NONE;
}