diff --git a/mysql/MySQLdb.py b/mysql/MySQLdb.py index 9fe7ba2..14d4e2e 100644 --- a/mysql/MySQLdb.py +++ b/mysql/MySQLdb.py @@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2] from _mysql import * from time import localtime -import re +import re, types threadsafety = 1 apilevel = "2.0" paramstyle = "format" +def Long2Int(l): return str(l)[:-1] # drop the trailing L +def None2NULL(d): return "NULL" +def String2literal(s): return "'%s'" % escape_string(str(s)) + +quote_conv = { types.IntType: str, + types.LongType: Long2Int, + types.FloatType: str, + types.NoneType: None2NULL, + types.StringType: String2literal } + type_conv = { FIELD_TYPE.TINY: int, FIELD_TYPE.SHORT: int, FIELD_TYPE.LONG: int, @@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int, FIELD_TYPE.YEAR: int } try: - from DateTime import Date, Time, Timestamp, ISO + from DateTime import Date, Time, Timestamp, ISO, DateTimeType def DateFromTicks(ticks): return apply(Date, localtime(ticks)[:3]) @@ -59,6 +69,10 @@ try: #type_conv[FIELD_TYPE.TIME] = ISO.ParseTime type_conv[FIELD_TYPE.DATE] = ISO.ParseDate + #def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d) + + #quote_conv[DateTimeType] = DateTime2literal + except ImportError: # no DateTime? We'll muddle through somehow. from time import strftime @@ -108,11 +122,10 @@ def Binary(x): return str(x) insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE) -def escape_dict(d): +def escape_dict(d, qc): d2 = {} for k,v in d.items(): - if v is None: d2[k] = "NULL" - else: d2[k] = "'%s'" % escape_string(str(v)) + d2[k] = qc.get(type(v), String2literal)(v) return d2 @@ -149,15 +162,16 @@ class _Cursor: args -- sequence or mapping, parameters to use with query.""" from types import ListType, TupleType from string import rfind, join, split, atoi + qc = self.connection.quote_conv if not args: self._query(query) elif type(args) is ListType and type(args[0]) is TupleType: self.executemany(query, args) # deprecated else: try: - self._query(query % escape_row(args)) + self._query(query % escape_row(args, qc)) except TypeError: - self._query(query % escape_dict(args)) + self._query(query % escape_dict(args, qc)) def executemany(self, query, args): """cursor.executemany(self, query, args) @@ -174,13 +188,14 @@ class _Cursor: if not m: raise ProgrammingError, "can't find values" p = m.start(1) escape = escape_row + qc = self.connection.quote_conv try: - q = [query % escape(args[0])] + q = [query % escape(args[0], qc)] except TypeError: escape = escape_dict - q = [query % escape(args[0])] + q = [query % escape(args[0], qc)] qv = query[p:] - for a in args[1:]: q.append(qv % escape(a)) + for a in args[1:]: q.append(qv % escape(a, qc)) self._query(join(q, ',\n')) def _query(self, q): @@ -252,6 +267,7 @@ class Connection: def __init__(self, **kwargs): from _mysql import connect if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy() + self.quote_conv = kwargs.get('quote_conv', quote_conv.copy()) self.db = apply(connect, (), kwargs) def close(self): diff --git a/mysql/_mysqlmodule.c b/mysql/_mysqlmodule.c index 26a1b83..6ba96c5 100644 --- a/mysql/_mysqlmodule.c +++ b/mysql/_mysqlmodule.c @@ -548,49 +548,48 @@ _mysql_escape_row(self, args) PyObject *self; PyObject *args; { - PyObject *o=NULL, *r=NULL, *item, *quoted, *str, *itemstr; + PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *str, *itemstr, + *itemtype, *itemconv; char *in, *out; int i, n, len, size; - if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2; + if (!PyArg_ParseTuple(args, "OO:escape_row", &o, &d)) goto error; if (!PySequence_Check(o)) { PyErr_SetString(PyExc_TypeError, "sequence required"); - goto error2; + goto error; } - if (!(n = PyObject_Length(o))) goto error2; + if (!PyMapping_Check(d)) { + PyErr_SetString(PyExc_TypeError, "mapping required"); + goto error; + } + if (!(n = PyObject_Length(o))) goto error; if (!(r = PyTuple_New(n))) goto error; for (i=0; i