mirror of
https://github.com/PyMySQL/mysqlclient.git
synced 2025-08-15 19:31:54 +08:00
Minor bugfix on _mysql: Set exception if non-sequence is passed to
escape_row(). Updating for version 2.0 of the API.
This commit is contained in:
104
mysql/MySQLdb.py
104
mysql/MySQLdb.py
@ -1,16 +1,56 @@
|
||||
import _mysql
|
||||
from _mysql import *
|
||||
from DateTime import Date, Time, Timestamp, ISO
|
||||
from time import localtime
|
||||
import re
|
||||
|
||||
threadsafety = 1
|
||||
apllevel = "1.1"
|
||||
paramstyle = "percent"
|
||||
apllevel = "2.0"
|
||||
paramstyle = "format"
|
||||
|
||||
def DateFromTicks(ticks):
|
||||
return apply(Date, localtime(ticks)[:3])
|
||||
|
||||
def TimeFromTicks(ticks):
|
||||
return apply(Time, localtime(ticks)[3:6])
|
||||
|
||||
def TimestampFromTicks(ticks):
|
||||
return apply(Timestamp, localtime(ticks)[:6])
|
||||
|
||||
|
||||
class DBAPITypeObject:
|
||||
|
||||
def __init__(self,*values):
|
||||
self.values = values
|
||||
|
||||
def __cmp__(self,other):
|
||||
if other in self.values:
|
||||
return 0
|
||||
if other < self.values:
|
||||
return 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
Set = DBAPITypeObject
|
||||
|
||||
STRING = Set(FIELD_TYPE.CHAR, FIELD_TYPE.ENUM, FIELD_TYPE.INTERVAL,
|
||||
FIELD_TYPE.SET, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING)
|
||||
BINARY = Set(FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.MEDIUM_BLOB,
|
||||
FIELD_TYPE.TINY_BLOB)
|
||||
NUMBER = Set(FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
|
||||
FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
|
||||
FIELD_TYPE.TINY, FIELD_TYPE.YEAR)
|
||||
DATE = Set(FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE)
|
||||
TIME = Set(FIELD_TYPE.TIME)
|
||||
TIMESTAMP = Set(FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME)
|
||||
ROWID = Set()
|
||||
|
||||
def Binary(x): return str(x)
|
||||
|
||||
def DATE(d): return d.Format("%Y-%m-%d")
|
||||
def TIME(d): return d.Format("%H:%M:%S")
|
||||
def TIMESTAMP(d): return d.Format("%Y-%m-%d %H:%M:%S")
|
||||
def format_DATE(d): return d.Format("%Y-%m-%d")
|
||||
def format_TIME(d): return d.Format("%H:%M:%S")
|
||||
def format_TIMESTAMP(d): return d.Format("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def mysql_timestamp_converter(s):
|
||||
parts = map(int, filter(None, (s[:4],s[4:6],s[6:8],s[8:10],s[10:12],s[12:14])))
|
||||
@ -21,6 +61,12 @@ type_conv[FIELD_TYPE.DATETIME] = ISO.ParseDateTime
|
||||
type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
|
||||
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
|
||||
|
||||
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
|
||||
|
||||
def escape_dict(d):
|
||||
d2 = {}
|
||||
for k,v in d.items(): d2[k] = "'%s'" % escape_string(str(v))
|
||||
return d2
|
||||
|
||||
class Cursor:
|
||||
|
||||
@ -34,28 +80,37 @@ class Cursor:
|
||||
self.warnings = warnings
|
||||
self.use = use
|
||||
|
||||
def setinputsizes(self, size): pass
|
||||
def setinputsizes(self, *args): pass
|
||||
|
||||
def setoutputsizes(self, size): pass
|
||||
def setoutputsizes(self, *args): pass
|
||||
|
||||
def execute(self, query, args=None):
|
||||
from types import TupleType
|
||||
from types import ListType, TupleType
|
||||
from string import rfind, join, split, atoi
|
||||
if not args:
|
||||
self._query(query)
|
||||
elif type(args) is TupleType:
|
||||
self._query(query % escape_row(args))
|
||||
else:
|
||||
self.executemany(query, args) # deprecated
|
||||
elif type(args) is ListType and type(args[0]) is TupleType:
|
||||
self.executemany(query, args) # deprecated
|
||||
else:
|
||||
try:
|
||||
self._query(query % escape_row(args))
|
||||
except TypeError:
|
||||
self._query(query % escape_dict(args))
|
||||
|
||||
def executemany(self, query, args=None):
|
||||
from string import rfind, join
|
||||
p = rfind(query, '(')
|
||||
if p == -1: raise ProgrammingError, "can't find values"
|
||||
from string import join
|
||||
m = insert_values(query)
|
||||
if not m: raise ProgrammingError, "can't find values"
|
||||
p = m.start(1)
|
||||
n = len(args)-1
|
||||
q = [query % escape_row(args[0])]
|
||||
escape = escape_row
|
||||
try:
|
||||
q = [query % escape(args[0])]
|
||||
except TypeError:
|
||||
escape = escape_dict
|
||||
q = [query % escape(args[0])]
|
||||
qv = query[p:]
|
||||
for a in args[1:]: q.append(qv % escape_row(a))
|
||||
for a in args[1:]: q.append(qv % escape(a))
|
||||
self._query(join(q, ',\n'))
|
||||
|
||||
def _query(self, q):
|
||||
@ -89,30 +144,21 @@ class Cursor:
|
||||
|
||||
def fetchall(self): return self.result.fetch_all_rows()
|
||||
|
||||
def nextset(self): pass
|
||||
def nextset(self): return None
|
||||
|
||||
|
||||
class Connection:
|
||||
|
||||
CursorClass = Cursor
|
||||
|
||||
def __init__(self, dsn=None, user=None, password=None,
|
||||
host=None, database=None, **kwargs):
|
||||
newargs = {}
|
||||
if user: newargs['user'] = user
|
||||
if password: newargs['passwd'] = password
|
||||
if host: newargs['host'] = host
|
||||
if database: newargs['db'] = database
|
||||
newargs.update(kwargs)
|
||||
self.db = apply(_mysql.connect, (), newargs)
|
||||
def __init__(self, **kwargs):
|
||||
self.db = apply(_mysql.connect, (), kwargs)
|
||||
|
||||
def close(self):
|
||||
self.db.close()
|
||||
|
||||
def commit(self): pass
|
||||
|
||||
# def rollback(self): raise OperationalError, "transactions not supported"
|
||||
|
||||
def cursor(self, name=''):
|
||||
return self.CursorClass(self, name)
|
||||
|
||||
|
Reference in New Issue
Block a user