Add a dictionary that controls conversion of Python types into

MySQL literals.
This commit is contained in:
adustman
2000-01-07 00:30:44 +00:00
parent 589df8797c
commit 42bd06e986
2 changed files with 58 additions and 38 deletions

View File

@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
from _mysql import * from _mysql import *
from time import localtime from time import localtime
import re import re, types
threadsafety = 1 threadsafety = 1
apilevel = "2.0" apilevel = "2.0"
paramstyle = "format" paramstyle = "format"
def Long2Int(l): return str(l)[:-1] # drop the trailing L
def None2NULL(d): return "NULL"
def String2literal(s): return "'%s'" % escape_string(str(s))
quote_conv = { types.IntType: str,
types.LongType: Long2Int,
types.FloatType: str,
types.NoneType: None2NULL,
types.StringType: String2literal }
type_conv = { FIELD_TYPE.TINY: int, type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int, FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: int, FIELD_TYPE.LONG: int,
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.YEAR: int } FIELD_TYPE.YEAR: int }
try: try:
from DateTime import Date, Time, Timestamp, ISO from DateTime import Date, Time, Timestamp, ISO, DateTimeType
def DateFromTicks(ticks): def DateFromTicks(ticks):
return apply(Date, localtime(ticks)[:3]) return apply(Date, localtime(ticks)[:3])
@ -59,6 +69,10 @@ try:
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime #type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
#quote_conv[DateTimeType] = DateTime2literal
except ImportError: except ImportError:
# no DateTime? We'll muddle through somehow. # no DateTime? We'll muddle through somehow.
from time import strftime from time import strftime
@ -108,11 +122,10 @@ def Binary(x): return str(x)
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE) insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
def escape_dict(d): def escape_dict(d, qc):
d2 = {} d2 = {}
for k,v in d.items(): for k,v in d.items():
if v is None: d2[k] = "NULL" d2[k] = qc.get(type(v), String2literal)(v)
else: d2[k] = "'%s'" % escape_string(str(v))
return d2 return d2
@ -149,15 +162,16 @@ class _Cursor:
args -- sequence or mapping, parameters to use with query.""" args -- sequence or mapping, parameters to use with query."""
from types import ListType, TupleType from types import ListType, TupleType
from string import rfind, join, split, atoi from string import rfind, join, split, atoi
qc = self.connection.quote_conv
if not args: if not args:
self._query(query) self._query(query)
elif type(args) is ListType and type(args[0]) is TupleType: elif type(args) is ListType and type(args[0]) is TupleType:
self.executemany(query, args) # deprecated self.executemany(query, args) # deprecated
else: else:
try: try:
self._query(query % escape_row(args)) self._query(query % escape_row(args, qc))
except TypeError: except TypeError:
self._query(query % escape_dict(args)) self._query(query % escape_dict(args, qc))
def executemany(self, query, args): def executemany(self, query, args):
"""cursor.executemany(self, query, args) """cursor.executemany(self, query, args)
@ -174,13 +188,14 @@ class _Cursor:
if not m: raise ProgrammingError, "can't find values" if not m: raise ProgrammingError, "can't find values"
p = m.start(1) p = m.start(1)
escape = escape_row escape = escape_row
qc = self.connection.quote_conv
try: try:
q = [query % escape(args[0])] q = [query % escape(args[0], qc)]
except TypeError: except TypeError:
escape = escape_dict escape = escape_dict
q = [query % escape(args[0])] q = [query % escape(args[0], qc)]
qv = query[p:] qv = query[p:]
for a in args[1:]: q.append(qv % escape(a)) for a in args[1:]: q.append(qv % escape(a, qc))
self._query(join(q, ',\n')) self._query(join(q, ',\n'))
def _query(self, q): def _query(self, q):
@ -252,6 +267,7 @@ class Connection:
def __init__(self, **kwargs): def __init__(self, **kwargs):
from _mysql import connect from _mysql import connect
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy() if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
self.db = apply(connect, (), kwargs) self.db = apply(connect, (), kwargs)
def close(self): def close(self):

View File

@ -548,49 +548,48 @@ _mysql_escape_row(self, args)
PyObject *self; PyObject *self;
PyObject *args; PyObject *args;
{ {
PyObject *o=NULL, *r=NULL, *item, *quoted, *str, *itemstr; PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *str, *itemstr,
*itemtype, *itemconv;
char *in, *out; char *in, *out;
int i, n, len, size; int i, n, len, size;
if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2; if (!PyArg_ParseTuple(args, "OO:escape_row", &o, &d)) goto error;
if (!PySequence_Check(o)) { if (!PySequence_Check(o)) {
PyErr_SetString(PyExc_TypeError, "sequence required"); PyErr_SetString(PyExc_TypeError, "sequence required");
goto error2; goto error;
} }
if (!(n = PyObject_Length(o))) goto error2; if (!PyMapping_Check(d)) {
PyErr_SetString(PyExc_TypeError, "mapping required");
goto error;
}
if (!(n = PyObject_Length(o))) goto error;
if (!(r = PyTuple_New(n))) goto error; if (!(r = PyTuple_New(n))) goto error;
for (i=0; i<n; i++) { for (i=0; i<n; i++) {
if (!(item = PySequence_GetItem(o, i))) goto error; if (!(item = PySequence_GetItem(o, i))) goto error;
if (item == Py_None) { if (!(itemtype = PyObject_Type(item)))
quoted = _mysql_NULL;
Py_INCREF(_mysql_NULL);
}
else {
if (!(itemstr = PyObject_Str(item)))
goto error; goto error;
if (!(in = PyString_AsString(itemstr))) { itemconv = PyObject_GetItem(d, itemtype);
Py_DECREF(itemstr); Py_DECREF(itemtype);
if (!itemconv) {
PyErr_Clear();
itemconv = PyObject_GetItem(d,
(PyObject *) &PyString_Type);
}
if (!itemconv) {
PyErr_SetString(PyExc_TypeError,
"no default type converter defined");
goto error; goto error;
} }
size = PyString_GET_SIZE(itemstr); quoted = PyObject_CallFunction(itemconv, "O", item);
str = PyString_FromStringAndSize((char *)NULL, size*2+3); Py_DECREF(itemconv);
if (!str) goto error; if (!quoted) goto error;
out = PyString_AS_STRING(str);
len = mysql_escape_string(out+1, in, size);
*out = '\'' ;
*(out+len+1) = '\'' ;
*(out+len+2) = 0;
if (_PyString_Resize(&str, len+2) < 0)
goto error;
Py_DECREF(itemstr);
quoted = str;
}
Py_DECREF(item); Py_DECREF(item);
PyTuple_SET_ITEM(r, i, quoted); PyTuple_SET_ITEM(r, i, quoted);
} }
return r; return r;
error: error:
Py_XDECREF(r); Py_XDECREF(r);
error2: Py_XDECREF(o);
Py_XDECREF(d);
return NULL; return NULL;
} }
@ -1347,7 +1346,12 @@ all converted reasonably, except DECIMAL.\n\
\n\ \n\
result.describe() produces a DB API description of the rows.\n\ result.describe() produces a DB API description of the rows.\n\
\n\ \n\
escape_row() accepts a sequence of items, converts them to strings, does\n\ escape_row() accepts a sequence of items and a type conversion dictionary.\n\
Using the type of the item, it gets a converter function from the dictionary\n\
(uses the string type if the item type is not found) and applies this to the\n\
item. the result should be converted to strings with all the necessary\n\
quoting.\n\
\n\
mysql_escape_string() on them, and returns them as a tuple.\n\ mysql_escape_string() on them, and returns them as a tuple.\n\
\n\ \n\
result.field_flags() returns the field flags for the result.\n\ result.field_flags() returns the field flags for the result.\n\