Add a dictionary that controls conversion of Python types into

MySQL literals.
This commit is contained in:
adustman
2000-01-07 00:30:44 +00:00
parent 589df8797c
commit 42bd06e986
2 changed files with 58 additions and 38 deletions

View File

@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
from _mysql import *
from time import localtime
import re
import re, types
threadsafety = 1
apilevel = "2.0"
paramstyle = "format"
def Long2Int(l): return str(l)[:-1] # drop the trailing L
def None2NULL(d): return "NULL"
def String2literal(s): return "'%s'" % escape_string(str(s))
quote_conv = { types.IntType: str,
types.LongType: Long2Int,
types.FloatType: str,
types.NoneType: None2NULL,
types.StringType: String2literal }
type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: int,
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.YEAR: int }
try:
from DateTime import Date, Time, Timestamp, ISO
from DateTime import Date, Time, Timestamp, ISO, DateTimeType
def DateFromTicks(ticks):
return apply(Date, localtime(ticks)[:3])
@ -59,6 +69,10 @@ try:
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
#quote_conv[DateTimeType] = DateTime2literal
except ImportError:
# no DateTime? We'll muddle through somehow.
from time import strftime
@ -108,11 +122,10 @@ def Binary(x): return str(x)
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
def escape_dict(d):
def escape_dict(d, qc):
d2 = {}
for k,v in d.items():
if v is None: d2[k] = "NULL"
else: d2[k] = "'%s'" % escape_string(str(v))
d2[k] = qc.get(type(v), String2literal)(v)
return d2
@ -149,15 +162,16 @@ class _Cursor:
args -- sequence or mapping, parameters to use with query."""
from types import ListType, TupleType
from string import rfind, join, split, atoi
qc = self.connection.quote_conv
if not args:
self._query(query)
elif type(args) is ListType and type(args[0]) is TupleType:
self.executemany(query, args) # deprecated
else:
try:
self._query(query % escape_row(args))
self._query(query % escape_row(args, qc))
except TypeError:
self._query(query % escape_dict(args))
self._query(query % escape_dict(args, qc))
def executemany(self, query, args):
"""cursor.executemany(self, query, args)
@ -174,13 +188,14 @@ class _Cursor:
if not m: raise ProgrammingError, "can't find values"
p = m.start(1)
escape = escape_row
qc = self.connection.quote_conv
try:
q = [query % escape(args[0])]
q = [query % escape(args[0], qc)]
except TypeError:
escape = escape_dict
q = [query % escape(args[0])]
q = [query % escape(args[0], qc)]
qv = query[p:]
for a in args[1:]: q.append(qv % escape(a))
for a in args[1:]: q.append(qv % escape(a, qc))
self._query(join(q, ',\n'))
def _query(self, q):
@ -252,6 +267,7 @@ class Connection:
def __init__(self, **kwargs):
from _mysql import connect
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
self.db = apply(connect, (), kwargs)
def close(self):

View File

@ -548,49 +548,48 @@ _mysql_escape_row(self, args)
PyObject *self;
PyObject *args;
{
PyObject *o=NULL, *r=NULL, *item, *quoted, *str, *itemstr;
PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *str, *itemstr,
*itemtype, *itemconv;
char *in, *out;
int i, n, len, size;
if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2;
if (!PyArg_ParseTuple(args, "OO:escape_row", &o, &d)) goto error;
if (!PySequence_Check(o)) {
PyErr_SetString(PyExc_TypeError, "sequence required");
goto error2;
goto error;
}
if (!(n = PyObject_Length(o))) goto error2;
if (!PyMapping_Check(d)) {
PyErr_SetString(PyExc_TypeError, "mapping required");
goto error;
}
if (!(n = PyObject_Length(o))) goto error;
if (!(r = PyTuple_New(n))) goto error;
for (i=0; i<n; i++) {
if (!(item = PySequence_GetItem(o, i))) goto error;
if (item == Py_None) {
quoted = _mysql_NULL;
Py_INCREF(_mysql_NULL);
}
else {
if (!(itemstr = PyObject_Str(item)))
if (!(itemtype = PyObject_Type(item)))
goto error;
if (!(in = PyString_AsString(itemstr))) {
Py_DECREF(itemstr);
itemconv = PyObject_GetItem(d, itemtype);
Py_DECREF(itemtype);
if (!itemconv) {
PyErr_Clear();
itemconv = PyObject_GetItem(d,
(PyObject *) &PyString_Type);
}
if (!itemconv) {
PyErr_SetString(PyExc_TypeError,
"no default type converter defined");
goto error;
}
size = PyString_GET_SIZE(itemstr);
str = PyString_FromStringAndSize((char *)NULL, size*2+3);
if (!str) goto error;
out = PyString_AS_STRING(str);
len = mysql_escape_string(out+1, in, size);
*out = '\'' ;
*(out+len+1) = '\'' ;
*(out+len+2) = 0;
if (_PyString_Resize(&str, len+2) < 0)
goto error;
Py_DECREF(itemstr);
quoted = str;
}
quoted = PyObject_CallFunction(itemconv, "O", item);
Py_DECREF(itemconv);
if (!quoted) goto error;
Py_DECREF(item);
PyTuple_SET_ITEM(r, i, quoted);
}
return r;
error:
Py_XDECREF(r);
error2:
Py_XDECREF(o);
Py_XDECREF(d);
return NULL;
}
@ -1347,7 +1346,12 @@ all converted reasonably, except DECIMAL.\n\
\n\
result.describe() produces a DB API description of the rows.\n\
\n\
escape_row() accepts a sequence of items, converts them to strings, does\n\
escape_row() accepts a sequence of items and a type conversion dictionary.\n\
Using the type of the item, it gets a converter function from the dictionary\n\
(uses the string type if the item type is not found) and applies this to the\n\
item. the result should be converted to strings with all the necessary\n\
quoting.\n\
\n\
mysql_escape_string() on them, and returns them as a tuple.\n\
\n\
result.field_flags() returns the field flags for the result.\n\