mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-16 12:27:03 +08:00
Add a dictionary that controls conversion of Python types into
MySQL literals.
This commit is contained in:
@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
|
||||
|
||||
from _mysql import *
|
||||
from time import localtime
|
||||
import re
|
||||
import re, types
|
||||
|
||||
threadsafety = 1
|
||||
apilevel = "2.0"
|
||||
paramstyle = "format"
|
||||
|
||||
def Long2Int(l): return str(l)[:-1] # drop the trailing L
|
||||
def None2NULL(d): return "NULL"
|
||||
def String2literal(s): return "'%s'" % escape_string(str(s))
|
||||
|
||||
quote_conv = { types.IntType: str,
|
||||
types.LongType: Long2Int,
|
||||
types.FloatType: str,
|
||||
types.NoneType: None2NULL,
|
||||
types.StringType: String2literal }
|
||||
|
||||
type_conv = { FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.SHORT: int,
|
||||
FIELD_TYPE.LONG: int,
|
||||
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.YEAR: int }
|
||||
|
||||
try:
|
||||
from DateTime import Date, Time, Timestamp, ISO
|
||||
from DateTime import Date, Time, Timestamp, ISO, DateTimeType
|
||||
|
||||
def DateFromTicks(ticks):
|
||||
return apply(Date, localtime(ticks)[:3])
|
||||
@ -59,6 +69,10 @@ try:
|
||||
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
||||
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
||||
|
||||
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
|
||||
|
||||
#quote_conv[DateTimeType] = DateTime2literal
|
||||
|
||||
except ImportError:
|
||||
# no DateTime? We'll muddle through somehow.
|
||||
from time import strftime
|
||||
@ -108,11 +122,10 @@ def Binary(x): return str(x)
|
||||
|
||||
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
|
||||
|
||||
def escape_dict(d):
|
||||
def escape_dict(d, qc):
|
||||
d2 = {}
|
||||
for k,v in d.items():
|
||||
if v is None: d2[k] = "NULL"
|
||||
else: d2[k] = "'%s'" % escape_string(str(v))
|
||||
d2[k] = qc.get(type(v), String2literal)(v)
|
||||
return d2
|
||||
|
||||
|
||||
@ -149,15 +162,16 @@ class _Cursor:
|
||||
args -- sequence or mapping, parameters to use with query."""
|
||||
from types import ListType, TupleType
|
||||
from string import rfind, join, split, atoi
|
||||
qc = self.connection.quote_conv
|
||||
if not args:
|
||||
self._query(query)
|
||||
elif type(args) is ListType and type(args[0]) is TupleType:
|
||||
self.executemany(query, args) # deprecated
|
||||
else:
|
||||
try:
|
||||
self._query(query % escape_row(args))
|
||||
self._query(query % escape_row(args, qc))
|
||||
except TypeError:
|
||||
self._query(query % escape_dict(args))
|
||||
self._query(query % escape_dict(args, qc))
|
||||
|
||||
def executemany(self, query, args):
|
||||
"""cursor.executemany(self, query, args)
|
||||
@ -174,13 +188,14 @@ class _Cursor:
|
||||
if not m: raise ProgrammingError, "can't find values"
|
||||
p = m.start(1)
|
||||
escape = escape_row
|
||||
qc = self.connection.quote_conv
|
||||
try:
|
||||
q = [query % escape(args[0])]
|
||||
q = [query % escape(args[0], qc)]
|
||||
except TypeError:
|
||||
escape = escape_dict
|
||||
q = [query % escape(args[0])]
|
||||
q = [query % escape(args[0], qc)]
|
||||
qv = query[p:]
|
||||
for a in args[1:]: q.append(qv % escape(a))
|
||||
for a in args[1:]: q.append(qv % escape(a, qc))
|
||||
self._query(join(q, ',\n'))
|
||||
|
||||
def _query(self, q):
|
||||
@ -252,6 +267,7 @@ class Connection:
|
||||
def __init__(self, **kwargs):
|
||||
from _mysql import connect
|
||||
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
|
||||
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
|
||||
self.db = apply(connect, (), kwargs)
|
||||
|
||||
def close(self):
|
||||
|
@ -548,49 +548,48 @@ _mysql_escape_row(self, args)
|
||||
PyObject *self;
|
||||
PyObject *args;
|
||||
{
|
||||
PyObject *o=NULL, *r=NULL, *item, *quoted, *str, *itemstr;
|
||||
PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *str, *itemstr,
|
||||
*itemtype, *itemconv;
|
||||
char *in, *out;
|
||||
int i, n, len, size;
|
||||
if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2;
|
||||
if (!PyArg_ParseTuple(args, "OO:escape_row", &o, &d)) goto error;
|
||||
if (!PySequence_Check(o)) {
|
||||
PyErr_SetString(PyExc_TypeError, "sequence required");
|
||||
goto error2;
|
||||
goto error;
|
||||
}
|
||||
if (!(n = PyObject_Length(o))) goto error2;
|
||||
if (!PyMapping_Check(d)) {
|
||||
PyErr_SetString(PyExc_TypeError, "mapping required");
|
||||
goto error;
|
||||
}
|
||||
if (!(n = PyObject_Length(o))) goto error;
|
||||
if (!(r = PyTuple_New(n))) goto error;
|
||||
for (i=0; i<n; i++) {
|
||||
if (!(item = PySequence_GetItem(o, i))) goto error;
|
||||
if (item == Py_None) {
|
||||
quoted = _mysql_NULL;
|
||||
Py_INCREF(_mysql_NULL);
|
||||
if (!(itemtype = PyObject_Type(item)))
|
||||
goto error;
|
||||
itemconv = PyObject_GetItem(d, itemtype);
|
||||
Py_DECREF(itemtype);
|
||||
if (!itemconv) {
|
||||
PyErr_Clear();
|
||||
itemconv = PyObject_GetItem(d,
|
||||
(PyObject *) &PyString_Type);
|
||||
}
|
||||
else {
|
||||
if (!(itemstr = PyObject_Str(item)))
|
||||
goto error;
|
||||
if (!(in = PyString_AsString(itemstr))) {
|
||||
Py_DECREF(itemstr);
|
||||
goto error;
|
||||
}
|
||||
size = PyString_GET_SIZE(itemstr);
|
||||
str = PyString_FromStringAndSize((char *)NULL, size*2+3);
|
||||
if (!str) goto error;
|
||||
out = PyString_AS_STRING(str);
|
||||
len = mysql_escape_string(out+1, in, size);
|
||||
*out = '\'' ;
|
||||
*(out+len+1) = '\'' ;
|
||||
*(out+len+2) = 0;
|
||||
if (_PyString_Resize(&str, len+2) < 0)
|
||||
goto error;
|
||||
Py_DECREF(itemstr);
|
||||
quoted = str;
|
||||
if (!itemconv) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"no default type converter defined");
|
||||
goto error;
|
||||
}
|
||||
quoted = PyObject_CallFunction(itemconv, "O", item);
|
||||
Py_DECREF(itemconv);
|
||||
if (!quoted) goto error;
|
||||
Py_DECREF(item);
|
||||
PyTuple_SET_ITEM(r, i, quoted);
|
||||
}
|
||||
return r;
|
||||
error:
|
||||
Py_XDECREF(r);
|
||||
error2:
|
||||
Py_XDECREF(o);
|
||||
Py_XDECREF(d);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -1347,7 +1346,12 @@ all converted reasonably, except DECIMAL.\n\
|
||||
\n\
|
||||
result.describe() produces a DB API description of the rows.\n\
|
||||
\n\
|
||||
escape_row() accepts a sequence of items, converts them to strings, does\n\
|
||||
escape_row() accepts a sequence of items and a type conversion dictionary.\n\
|
||||
Using the type of the item, it gets a converter function from the dictionary\n\
|
||||
(uses the string type if the item type is not found) and applies this to the\n\
|
||||
item. the result should be converted to strings with all the necessary\n\
|
||||
quoting.\n\
|
||||
\n\
|
||||
mysql_escape_string() on them, and returns them as a tuple.\n\
|
||||
\n\
|
||||
result.field_flags() returns the field flags for the result.\n\
|
||||
|
Reference in New Issue
Block a user