diff --git a/mysql/MySQLdb.py b/mysql/MySQLdb.py index af2b626..1ee1432 100644 --- a/mysql/MySQLdb.py +++ b/mysql/MySQLdb.py @@ -32,8 +32,9 @@ try: except ImportError: _threading = None -def Long2Int(l): return str(l)[:-1] # drop the trailing L +def Long2Int(l): s = str(l); return s[-1] == 'L' and s[:-1] or s def None2NULL(d): return "NULL" +def Thing2Literal(o): return string_literal(str(o)) # MySQL-3.23.xx now has a new escape_string function that uses # the connection to determine what character set is in use and @@ -140,12 +141,6 @@ def Binary(x): return str(x) insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE) -def escape_dict(d, qc): - d2 = {} - for k,v in d.items(): - d2[k] = qc.get(type(v), String2Literal)(v) - return d2 - def _fetchall(result, *args): rows = r = list(apply(result.fetch_row, args)) while 1: diff --git a/mysql/_mysqlmodule.c b/mysql/_mysqlmodule.c index bc8e58c..e7372d6 100644 --- a/mysql/_mysqlmodule.c +++ b/mysql/_mysqlmodule.c @@ -586,15 +586,18 @@ _mysql_string_literal( _mysql_ConnectionObject *self, PyObject *args) { - PyObject *str; + PyObject *str, *s, *o; char *in, *out; int len, size; - if (!PyArg_ParseTuple(args, "s#:string_literal", &in, &size)) return NULL; + if (!PyArg_ParseTuple(args, "O:string_literal", &o)) return NULL; + s = PyObject_Str(o); + in = PyString_AsString(s); + size = PyString_GET_SIZE(s); str = PyString_FromStringAndSize((char *) NULL, size*2+3); if (!str) return PyErr_NoMemory(); out = PyString_AS_STRING(str); #if MYSQL_VERSION_ID < 32321 - len = mysql_escape_string(out+1, in, size); + len = mysql_escape_string(out+1, s, size); #else if (self) len = mysql_real_escape_string(&(self->connection), out+1, in, size); @@ -603,18 +606,47 @@ _mysql_string_literal( #endif *out = *(out+len+1) = '\''; if (_PyString_Resize(&str, len+2) < 0) return NULL; + Py_DECREF(s); return (str); } static PyObject *_mysql_NULL; +static PyObject * +_escape_item( + PyObject *item, + PyObject *d) +{ + PyObject *quoted, *itemtype, *itemconv; + int i, n; + if (!(itemtype = PyObject_Type(item))) + goto error; + itemconv = PyObject_GetItem(d, itemtype); + Py_DECREF(itemtype); + if (!itemconv) { + PyErr_Clear(); + itemconv = PyObject_GetItem(d, + (PyObject *) &PyString_Type); + } + if (!itemconv) { + PyErr_SetString(PyExc_TypeError, + "no default type converter defined"); + goto error; + } + quoted = PyObject_CallFunction(itemconv, "O", item); + Py_DECREF(itemconv); + if (!quoted) goto error; + return quoted; + error: + return NULL; +} + static PyObject * _mysql_escape_row( PyObject *self, PyObject *args) { - PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, - *itemtype, *itemconv; + PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted; int i, n; if (!PyArg_ParseTuple(args, "O!O!:escape_row", &PyTuple_Type, &o, &PyDict_Type, &d)) @@ -623,22 +655,7 @@ _mysql_escape_row( if (!(r = PyTuple_New(n))) goto error; for (i=0; i