Add a dictionary that controls conversion of Python types into

MySQL literals.
This commit is contained in:
adustman
2000-01-07 00:30:44 +00:00
parent 589df8797c
commit 42bd06e986
2 changed files with 58 additions and 38 deletions

View File

@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
from _mysql import *
from time import localtime
import re
import re, types
threadsafety = 1
apilevel = "2.0"
paramstyle = "format"
def Long2Int(l): return str(l)[:-1] # drop the trailing L
def None2NULL(d): return "NULL"
def String2literal(s): return "'%s'" % escape_string(str(s))
quote_conv = { types.IntType: str,
types.LongType: Long2Int,
types.FloatType: str,
types.NoneType: None2NULL,
types.StringType: String2literal }
type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: int,
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
FIELD_TYPE.YEAR: int }
try:
from DateTime import Date, Time, Timestamp, ISO
from DateTime import Date, Time, Timestamp, ISO, DateTimeType
def DateFromTicks(ticks):
return apply(Date, localtime(ticks)[:3])
@ -59,6 +69,10 @@ try:
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
#quote_conv[DateTimeType] = DateTime2literal
except ImportError:
# no DateTime? We'll muddle through somehow.
from time import strftime
@ -108,11 +122,10 @@ def Binary(x): return str(x)
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
def escape_dict(d):
def escape_dict(d, qc):
d2 = {}
for k,v in d.items():
if v is None: d2[k] = "NULL"
else: d2[k] = "'%s'" % escape_string(str(v))
d2[k] = qc.get(type(v), String2literal)(v)
return d2
@ -149,15 +162,16 @@ class _Cursor:
args -- sequence or mapping, parameters to use with query."""
from types import ListType, TupleType
from string import rfind, join, split, atoi
qc = self.connection.quote_conv
if not args:
self._query(query)
elif type(args) is ListType and type(args[0]) is TupleType:
self.executemany(query, args) # deprecated
else:
try:
self._query(query % escape_row(args))
self._query(query % escape_row(args, qc))
except TypeError:
self._query(query % escape_dict(args))
self._query(query % escape_dict(args, qc))
def executemany(self, query, args):
"""cursor.executemany(self, query, args)
@ -174,13 +188,14 @@ class _Cursor:
if not m: raise ProgrammingError, "can't find values"
p = m.start(1)
escape = escape_row
qc = self.connection.quote_conv
try:
q = [query % escape(args[0])]
q = [query % escape(args[0], qc)]
except TypeError:
escape = escape_dict
q = [query % escape(args[0])]
q = [query % escape(args[0], qc)]
qv = query[p:]
for a in args[1:]: q.append(qv % escape(a))
for a in args[1:]: q.append(qv % escape(a, qc))
self._query(join(q, ',\n'))
def _query(self, q):
@ -252,6 +267,7 @@ class Connection:
def __init__(self, **kwargs):
from _mysql import connect
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
self.db = apply(connect, (), kwargs)
def close(self):