mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 19:31:54 +08:00
Better support for new features.
Improved quoting.
This commit is contained in:
@ -32,21 +32,24 @@ try:
|
||||
except ImportError:
|
||||
_threading = None
|
||||
|
||||
def Long2Int(l): s = str(l); return s[-1] == 'L' and s[:-1] or s
|
||||
def None2NULL(d): return "NULL"
|
||||
def Thing2Literal(o): return string_literal(str(o))
|
||||
def Thing2Str(s, d={}): return str(s)
|
||||
def Long2Int(l, d={}): s = str(l); return s[-1] == 'L' and s[:-1] or s
|
||||
def None2NULL(o, d={}): return "NULL"
|
||||
def Thing2Literal(o, d={}): return string_literal(str(o))
|
||||
|
||||
# MySQL-3.23.xx now has a new escape_string function that uses
|
||||
# the connection to determine what character set is in use and
|
||||
# quote accordingly. So this will be overridden by the connect()
|
||||
# method.
|
||||
String2Literal = string_literal
|
||||
|
||||
quote_conv = { types.IntType: str,
|
||||
quote_conv = { types.IntType: Thing2Str,
|
||||
types.LongType: Long2Int,
|
||||
types.FloatType: str,
|
||||
types.FloatType: Thing2Str,
|
||||
types.NoneType: None2NULL,
|
||||
types.StringType: String2Literal }
|
||||
types.TupleType: escape_sequence,
|
||||
types.ListType: escape_sequence,
|
||||
types.DictType: escape_dict,
|
||||
types.StringType: Thing2Literal } # default
|
||||
|
||||
type_conv = { FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.SHORT: int,
|
||||
@ -83,11 +86,11 @@ try:
|
||||
|
||||
type_conv[FIELD_TYPE.TIMESTAMP] = mysql_timestamp_converter
|
||||
type_conv[FIELD_TYPE.DATETIME] = ISO.ParseDateTime
|
||||
type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
||||
type_conv[FIELD_TYPE.TIME] = ISO.ParseTimeDelta
|
||||
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
||||
|
||||
def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d)
|
||||
def DateTimeDelta2literal(d): return "'%s'" % format_TIME(d)
|
||||
def DateTime2literal(d, c={}): return "'%s'" % format_TIMESTAMP(d)
|
||||
def DateTimeDelta2literal(d, c={}): return "'%s'" % format_TIME(d)
|
||||
|
||||
quote_conv[DateTimeType] = DateTime2literal
|
||||
quote_conv[DateTimeDeltaType] = DateTimeDelta2literal
|
||||
@ -191,12 +194,13 @@ class BaseCursor:
|
||||
return self.executemany(query, args) # deprecated
|
||||
else:
|
||||
try:
|
||||
return self._query(query % escape_row(args, qc))
|
||||
return self._query(query % escape(args, qc))
|
||||
except TypeError, m:
|
||||
if m.args[0] in ("not enough arguments for format string",
|
||||
"not all arguments converted"):
|
||||
raise ProgrammingError, m.args[0]
|
||||
return self._query(query % escape_dict(args, qc))
|
||||
else:
|
||||
raise
|
||||
|
||||
def executemany(self, query, args):
|
||||
"""cursor.executemany(self, query, args)
|
||||
@ -213,7 +217,6 @@ class BaseCursor:
|
||||
m = insert_values.search(query)
|
||||
if not m: raise ProgrammingError, "can't find values"
|
||||
p = m.start(1)
|
||||
escape = escape_row
|
||||
qc = self.connection.quote_conv
|
||||
try:
|
||||
q = [query % escape(args[0], qc)]
|
||||
@ -221,8 +224,8 @@ class BaseCursor:
|
||||
if msg.args[0] in ("not enough arguments for format string",
|
||||
"not all arguments converted"):
|
||||
raise ProgrammingError, msg.args[0]
|
||||
escape = escape_dict
|
||||
q = [query % escape(args[0], qc)]
|
||||
else:
|
||||
raise
|
||||
qv = query[p:]
|
||||
for a in args[1:]: q.append(qv % escape(a, qc))
|
||||
return self._query(join(q, ',\n'))
|
||||
@ -440,6 +443,7 @@ class Connection:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
from _mysql import connect
|
||||
from string import split
|
||||
if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy()
|
||||
self.quote_conv = kwargs.get('quote_conv', quote_conv.copy())
|
||||
if kwargs.has_key('cursorclass'):
|
||||
@ -448,9 +452,10 @@ class Connection:
|
||||
else:
|
||||
self.cursorclass = Cursor
|
||||
self.db = apply(connect, (), kwargs)
|
||||
self.quote_conv[types.StringType] = self.db.string_literal
|
||||
self._server_info = i = self.db.get_server_info()
|
||||
self._server_version = int(i[0])*10000 + int(i[2:4])*100 + int(i[5:7])
|
||||
self.quote_conv[types.StringType] = self.Thing2Literal
|
||||
self._server_info = self.db.get_server_info()
|
||||
i = map(int, split(split(self._server_info, '-')[0],'.'))
|
||||
self._server_version = i[0]*10000 + i[1]*100 + i[2]
|
||||
if _threading: self.__lock = _threading.Lock()
|
||||
|
||||
if _threading:
|
||||
@ -460,6 +465,8 @@ class Connection:
|
||||
def _acquire(self, blocking=1): return 1
|
||||
def _release(self): return 1
|
||||
|
||||
def Thing2Literal(self, o, d={}): return self.db.string_literal(str(o))
|
||||
|
||||
def close(self):
|
||||
"""Close the connection. No further activity possible."""
|
||||
self.db.close()
|
||||
@ -503,11 +510,15 @@ class Connection:
|
||||
def string_literal(self, s): return self.db.string_literal(s)
|
||||
def thread_id(self): return self.db.thread_id()
|
||||
|
||||
def change_user(self, *args, **kwargs):
|
||||
def _try_feature(self, feature, *args, **kwargs):
|
||||
try:
|
||||
return apply(self.db.change_user, args, kwargs)
|
||||
return apply(getattr(self.db, feature), args, kwargs)
|
||||
except AttributeError:
|
||||
raise NotSupportedError, "not supported by MySQL library"
|
||||
def character_set_name(self):
|
||||
return self._try_feature('character_set_name')
|
||||
def change_user(self, *args, **kwargs):
|
||||
return apply(self._try_feature, ('change_user',)+args, kwargs)
|
||||
|
||||
|
||||
Connect = connect = Connection
|
||||
|
Reference in New Issue
Block a user