mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 19:31:54 +08:00
Add a dictionary that controls conversion of Python types into
MySQL literals.
This commit is contained in:
@ -18,12 +18,22 @@ __version__ = """$Revision$"""[11:-2]
|
||||
|
||||
from _mysql import *
|
||||
from time import localtime
|
||||
import re
|
||||
import re, types
|
||||
|
||||
threadsafety = 1
|
||||
apilevel = "2.0"
|
||||
paramstyle = "format"
|
||||
|
||||
def Long2Int(l): return str(l)[:-1] # drop the trailing L
|
||||
def None2NULL(d): return "NULL"
|
||||
def String2literal(s): return "'%s'" % escape_string(str(s))
|
||||
|
||||
quote_conv = { types.IntType: str,
|
||||
types.LongType: Long2Int,
|
||||
types.FloatType: str,
|
||||
types.NoneType: None2NULL,
|
||||
types.StringType: String2literal }
|
||||
|
||||
type_conv = { FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.SHORT: int,
|
||||
FIELD_TYPE.LONG: int,
|
||||
@ -34,7 +44,7 @@ type_conv = { FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.YEAR: int }
|
||||
|
||||
try:
|
||||
from DateTime import Date, Time, Timestamp, ISO
|
||||
from DateTime import Date, Time, Timestamp, ISO, DateTimeType
|
||||
|
||||
def DateFromTicks(ticks):
|
||||
return apply(Date, localtime(ticks)[:3])
|
||||
@ -59,6 +69,10 @@ try:
|
||||
#type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
||||
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
||||
|
||||
#def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
|
||||
|
||||
#quote_conv[DateTimeType] = DateTime2literal
|
||||
|
||||
except ImportError:
|
||||
# no DateTime? We'll muddle through somehow.
|
||||
from time import strftime
|
||||
@ -108,11 +122,10 @@ def Binary(x): return str(x)
|
||||
|
||||
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
|
||||
|
||||
def escape_dict(d):
|
||||
def escape_dict(d, qc):
|
||||
d2 = {}
|
||||
for k,v in d.items():
|
||||
if v is None: d2[k] = "NULL"
|
||||
else: d2[k] = "'%s'" % escape_string(str(v))
|
||||
d2[k] = qc.get(type(v), String2literal)(v)
|
||||
return d2
|
||||
|
||||
|
||||
@ -149,15 +162,16 @@ class _Cursor:
|
||||
args -- sequence or mapping, parameters to use with query."""
|
||||
from types import ListType, TupleType
|
||||
from string import rfind, join, split, atoi
|
||||
qc = self.connection.quote_conv
|
||||
if not args:
|
||||
self._query(query)
|
||||
elif type(args) is ListType and type(args[0]) is TupleType:
|
||||
self.executemany(query, args) # deprecated
|
||||
else:
|
||||
try:
|
||||
self._query(query % escape_row(args))
|
||||
self._query(query % escape_row(args, qc))
|
||||
except TypeError:
|
||||
self._query(query % escape_dict(args))
|
||||
self._query(query % escape_dict(args, qc))
|
||||
|
||||
def executemany(self, query, args):
|
||||
"""cursor.executemany(self, query, args)
|
||||
@ -174,13 +188,14 @@ class _Cursor:
|
||||
if not m: raise ProgrammingError, "can't find values"
|
||||
p = m.start(1)
|
||||
escape = escape_row
|
||||
qc = self.connection.quote_conv
|
||||
try:
|
||||
q = [query % escape(args[0])]
|
||||
q = [query % escape(args[0], qc)]
|
||||
except TypeError:
|
||||
escape = escape_dict
|
||||
q = [query % escape(args[0])]
|
||||
q = [query % escape(args[0], qc)]
|
||||
qv = query[p:]
|
||||
for a in args[1:]: q.append(qv % escape(a))
|
||||
for a in args[1:]: q.append(qv % escape(a, qc))
|
||||
self._query(join(q, ',\n'))
|
||||
|
||||
def _query(self, q):
|
||||
@ -252,6 +267,7 @@ class Connection:
|
||||
def __init__(self, **kwargs):
|
||||
from _mysql import connect
|
||||
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
|
||||
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
|
||||
self.db = apply(connect, (), kwargs)
|
||||
|
||||
def close(self):
|
||||
|
Reference in New Issue
Block a user