mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 19:31:54 +08:00
Add a dictionary that controls conversion of Python types into
MySQL literals.
This commit is contained in:
@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
|
|||||||
|
|
||||||
from _mysql import *
|
from _mysql import *
|
||||||
from time import localtime
|
from time import localtime
|
||||||
import re
|
import re, types
|
||||||
|
|
||||||
threadsafety = 1
|
threadsafety = 1
|
||||||
apilevel = "2.0"
|
apilevel = "2.0"
|
||||||
paramstyle = "format"
|
paramstyle = "format"
|
||||||
|
|
||||||
|
def Long2Int(l): return str(l)[:-1] # drop the trailing L
|
||||||
|
def None2NULL(d): return "NULL"
|
||||||
|
def String2literal(s): return "'%s'" % escape_string(str(s))
|
||||||
|
|
||||||
|
quote_conv = { types.IntType: str,
|
||||||
|
types.LongType: Long2Int,
|
||||||
|
types.FloatType: str,
|
||||||
|
types.NoneType: None2NULL,
|
||||||
|
types.StringType: String2literal }
|
||||||
|
|
||||||
type_conv = { FIELD_TYPE.TINY: int,
|
type_conv = { FIELD_TYPE.TINY: int,
|
||||||
FIELD_TYPE.SHORT: int,
|
FIELD_TYPE.SHORT: int,
|
||||||
FIELD_TYPE.LONG: int,
|
FIELD_TYPE.LONG: int,
|
||||||
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
|
|||||||
FIELD_TYPE.YEAR: int }
|
FIELD_TYPE.YEAR: int }
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from DateTime import Date, Time, Timestamp, ISO
|
from DateTime import Date, Time, Timestamp, ISO, DateTimeType
|
||||||
|
|
||||||
def DateFromTicks(ticks):
|
def DateFromTicks(ticks):
|
||||||
return apply(Date, localtime(ticks)[:3])
|
return apply(Date, localtime(ticks)[:3])
|
||||||
@ -59,6 +69,10 @@ try:
|
|||||||
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
||||||
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
||||||
|
|
||||||
|
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
|
||||||
|
|
||||||
|
#quote_conv[DateTimeType] = DateTime2literal
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# no DateTime? We'll muddle through somehow.
|
# no DateTime? We'll muddle through somehow.
|
||||||
from time import strftime
|
from time import strftime
|
||||||
@ -108,11 +122,10 @@ def Binary(x): return str(x)
|
|||||||
|
|
||||||
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
|
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
|
||||||
|
|
||||||
def escape_dict(d):
|
def escape_dict(d, qc):
|
||||||
d2 = {}
|
d2 = {}
|
||||||
for k,v in d.items():
|
for k,v in d.items():
|
||||||
if v is None: d2[k] = "NULL"
|
d2[k] = qc.get(type(v), String2literal)(v)
|
||||||
else: d2[k] = "'%s'" % escape_string(str(v))
|
|
||||||
return d2
|
return d2
|
||||||
|
|
||||||
|
|
||||||
@ -149,15 +162,16 @@ class _Cursor:
|
|||||||
args -- sequence or mapping, parameters to use with query."""
|
args -- sequence or mapping, parameters to use with query."""
|
||||||
from types import ListType, TupleType
|
from types import ListType, TupleType
|
||||||
from string import rfind, join, split, atoi
|
from string import rfind, join, split, atoi
|
||||||
|
qc = self.connection.quote_conv
|
||||||
if not args:
|
if not args:
|
||||||
self._query(query)
|
self._query(query)
|
||||||
elif type(args) is ListType and type(args[0]) is TupleType:
|
elif type(args) is ListType and type(args[0]) is TupleType:
|
||||||
self.executemany(query, args) # deprecated
|
self.executemany(query, args) # deprecated
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self._query(query % escape_row(args))
|
self._query(query % escape_row(args, qc))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
self._query(query % escape_dict(args))
|
self._query(query % escape_dict(args, qc))
|
||||||
|
|
||||||
def executemany(self, query, args):
|
def executemany(self, query, args):
|
||||||
"""cursor.executemany(self, query, args)
|
"""cursor.executemany(self, query, args)
|
||||||
@ -174,13 +188,14 @@ class _Cursor:
|
|||||||
if not m: raise ProgrammingError, "can't find values"
|
if not m: raise ProgrammingError, "can't find values"
|
||||||
p = m.start(1)
|
p = m.start(1)
|
||||||
escape = escape_row
|
escape = escape_row
|
||||||
|
qc = self.connection.quote_conv
|
||||||
try:
|
try:
|
||||||
q = [query % escape(args[0])]
|
q = [query % escape(args[0], qc)]
|
||||||
except TypeError:
|
except TypeError:
|
||||||
escape = escape_dict
|
escape = escape_dict
|
||||||
q = [query % escape(args[0])]
|
q = [query % escape(args[0], qc)]
|
||||||
qv = query[p:]
|
qv = query[p:]
|
||||||
for a in args[1:]: q.append(qv % escape(a))
|
for a in args[1:]: q.append(qv % escape(a, qc))
|
||||||
self._query(join(q, ',\n'))
|
self._query(join(q, ',\n'))
|
||||||
|
|
||||||
def _query(self, q):
|
def _query(self, q):
|
||||||
@ -252,6 +267,7 @@ class Connection:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
from _mysql import connect
|
from _mysql import connect
|
||||||
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
|
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
|
||||||
|
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
|
||||||
self.db = apply(connect, (), kwargs)
|
self.db = apply(connect, (), kwargs)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -548,49 +548,48 @@ _mysql_escape_row(self, args)
|
|||||||
PyObject *self;
|
PyObject *self;
|
||||||
PyObject *args;
|
PyObject *args;
|
||||||
{
|
{
|
||||||
PyObject *o=NULL, *r=NULL, *item, *quoted, *str, *itemstr;
|
PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *str, *itemstr,
|
||||||
|
*itemtype, *itemconv;
|
||||||
char *in, *out;
|
char *in, *out;
|
||||||
int i, n, len, size;
|
int i, n, len, size;
|
||||||
if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2;
|
if (!PyArg_ParseTuple(args, "OO:escape_row", &o, &d)) goto error;
|
||||||
if (!PySequence_Check(o)) {
|
if (!PySequence_Check(o)) {
|
||||||
PyErr_SetString(PyExc_TypeError, "sequence required");
|
PyErr_SetString(PyExc_TypeError, "sequence required");
|
||||||
goto error2;
|
goto error;
|
||||||
}
|
}
|
||||||
if (!(n = PyObject_Length(o))) goto error2;
|
if (!PyMapping_Check(d)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "mapping required");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (!(n = PyObject_Length(o))) goto error;
|
||||||
if (!(r = PyTuple_New(n))) goto error;
|
if (!(r = PyTuple_New(n))) goto error;
|
||||||
for (i=0; i<n; i++) {
|
for (i=0; i<n; i++) {
|
||||||
if (!(item = PySequence_GetItem(o, i))) goto error;
|
if (!(item = PySequence_GetItem(o, i))) goto error;
|
||||||
if (item == Py_None) {
|
if (!(itemtype = PyObject_Type(item)))
|
||||||
quoted = _mysql_NULL;
|
|
||||||
Py_INCREF(_mysql_NULL);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (!(itemstr = PyObject_Str(item)))
|
|
||||||
goto error;
|
goto error;
|
||||||
if (!(in = PyString_AsString(itemstr))) {
|
itemconv = PyObject_GetItem(d, itemtype);
|
||||||
Py_DECREF(itemstr);
|
Py_DECREF(itemtype);
|
||||||
|
if (!itemconv) {
|
||||||
|
PyErr_Clear();
|
||||||
|
itemconv = PyObject_GetItem(d,
|
||||||
|
(PyObject *) &PyString_Type);
|
||||||
|
}
|
||||||
|
if (!itemconv) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
"no default type converter defined");
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
size = PyString_GET_SIZE(itemstr);
|
quoted = PyObject_CallFunction(itemconv, "O", item);
|
||||||
str = PyString_FromStringAndSize((char *)NULL, size*2+3);
|
Py_DECREF(itemconv);
|
||||||
if (!str) goto error;
|
if (!quoted) goto error;
|
||||||
out = PyString_AS_STRING(str);
|
|
||||||
len = mysql_escape_string(out+1, in, size);
|
|
||||||
*out = '\'' ;
|
|
||||||
*(out+len+1) = '\'' ;
|
|
||||||
*(out+len+2) = 0;
|
|
||||||
if (_PyString_Resize(&str, len+2) < 0)
|
|
||||||
goto error;
|
|
||||||
Py_DECREF(itemstr);
|
|
||||||
quoted = str;
|
|
||||||
}
|
|
||||||
Py_DECREF(item);
|
Py_DECREF(item);
|
||||||
PyTuple_SET_ITEM(r, i, quoted);
|
PyTuple_SET_ITEM(r, i, quoted);
|
||||||
}
|
}
|
||||||
return r;
|
return r;
|
||||||
error:
|
error:
|
||||||
Py_XDECREF(r);
|
Py_XDECREF(r);
|
||||||
error2:
|
Py_XDECREF(o);
|
||||||
|
Py_XDECREF(d);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1347,7 +1346,12 @@ all converted reasonably, except DECIMAL.\n\
|
|||||||
\n\
|
\n\
|
||||||
result.describe() produces a DB API description of the rows.\n\
|
result.describe() produces a DB API description of the rows.\n\
|
||||||
\n\
|
\n\
|
||||||
escape_row() accepts a sequence of items, converts them to strings, does\n\
|
escape_row() accepts a sequence of items and a type conversion dictionary.\n\
|
||||||
|
Using the type of the item, it gets a converter function from the dictionary\n\
|
||||||
|
(uses the string type if the item type is not found) and applies this to the\n\
|
||||||
|
item. the result should be converted to strings with all the necessary\n\
|
||||||
|
quoting.\n\
|
||||||
|
\n\
|
||||||
mysql_escape_string() on them, and returns them as a tuple.\n\
|
mysql_escape_string() on them, and returns them as a tuple.\n\
|
||||||
\n\
|
\n\
|
||||||
result.field_flags() returns the field flags for the result.\n\
|
result.field_flags() returns the field flags for the result.\n\
|
||||||
|
Reference in New Issue
Block a user