Minor bugfix on _mysql: Set exception if non-sequence is passed to

escape_row().

Updating for version 2.0 of the API.
This commit is contained in:
adustman
1999-03-30 07:59:51 +00:00
parent 55d15aeb83
commit 675b285a7b
2 changed files with 79 additions and 30 deletions

View File

@ -1,16 +1,56 @@
import _mysql import _mysql
from _mysql import * from _mysql import *
from DateTime import Date, Time, Timestamp, ISO from DateTime import Date, Time, Timestamp, ISO
from time import localtime
import re
threadsafety = 1 threadsafety = 1
apllevel = "1.1" apllevel = "2.0"
paramstyle = "percent" paramstyle = "format"
def DateFromTicks(ticks):
return apply(Date, localtime(ticks)[:3])
def TimeFromTicks(ticks):
return apply(Time, localtime(ticks)[3:6])
def TimestampFromTicks(ticks):
return apply(Timestamp, localtime(ticks)[:6])
class DBAPITypeObject:
def __init__(self,*values):
self.values = values
def __cmp__(self,other):
if other in self.values:
return 0
if other < self.values:
return 1
else:
return -1
Set = DBAPITypeObject
STRING = Set(FIELD_TYPE.CHAR, FIELD_TYPE.ENUM, FIELD_TYPE.INTERVAL,
FIELD_TYPE.SET, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING)
BINARY = Set(FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.MEDIUM_BLOB,
FIELD_TYPE.TINY_BLOB)
NUMBER = Set(FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
FIELD_TYPE.TINY, FIELD_TYPE.YEAR)
DATE = Set(FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE)
TIME = Set(FIELD_TYPE.TIME)
TIMESTAMP = Set(FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME)
ROWID = Set()
def Binary(x): return str(x) def Binary(x): return str(x)
def DATE(d): return d.Format("%Y-%m-%d") def format_DATE(d): return d.Format("%Y-%m-%d")
def TIME(d): return d.Format("%H:%M:%S") def format_TIME(d): return d.Format("%H:%M:%S")
def TIMESTAMP(d): return d.Format("%Y-%m-%d %H:%M:%S") def format_TIMESTAMP(d): return d.Format("%Y-%m-%d %H:%M:%S")
def mysql_timestamp_converter(s): def mysql_timestamp_converter(s):
parts = map(int, filter(None, (s[:4],s[4:6],s[6:8],s[8:10],s[10:12],s[12:14]))) parts = map(int, filter(None, (s[:4],s[4:6],s[6:8],s[8:10],s[10:12],s[12:14])))
@ -21,6 +61,12 @@ type_conv[FIELD_TYPE.DATETIME] = ISO.ParseDateTime
type_conv[FIELD_TYPE.TIME] = ISO.ParseTime type_conv[FIELD_TYPE.TIME] = ISO.ParseTime
type_conv[FIELD_TYPE.DATE] = ISO.ParseDate type_conv[FIELD_TYPE.DATE] = ISO.ParseDate
insert_values = re.compile(r'values\s(\(.+\))', re.IGNORECASE)
def escape_dict(d):
d2 = {}
for k,v in d.items(): d2[k] = "'%s'" % escape_string(str(v))
return d2
class Cursor: class Cursor:
@ -34,28 +80,37 @@ class Cursor:
self.warnings = warnings self.warnings = warnings
self.use = use self.use = use
def setinputsizes(self, size): pass def setinputsizes(self, *args): pass
def setoutputsizes(self, size): pass def setoutputsizes(self, *args): pass
def execute(self, query, args=None): def execute(self, query, args=None):
from types import TupleType from types import ListType, TupleType
from string import rfind, join, split, atoi from string import rfind, join, split, atoi
if not args: if not args:
self._query(query) self._query(query)
elif type(args) is TupleType: elif type(args) is ListType and type(args[0]) is TupleType:
self._query(query % escape_row(args)) self.executemany(query, args) # deprecated
else: else:
self.executemany(query, args) # deprecated try:
self._query(query % escape_row(args))
except TypeError:
self._query(query % escape_dict(args))
def executemany(self, query, args=None): def executemany(self, query, args=None):
from string import rfind, join from string import join
p = rfind(query, '(') m = insert_values(query)
if p == -1: raise ProgrammingError, "can't find values" if not m: raise ProgrammingError, "can't find values"
p = m.start(1)
n = len(args)-1 n = len(args)-1
q = [query % escape_row(args[0])] escape = escape_row
try:
q = [query % escape(args[0])]
except TypeError:
escape = escape_dict
q = [query % escape(args[0])]
qv = query[p:] qv = query[p:]
for a in args[1:]: q.append(qv % escape_row(a)) for a in args[1:]: q.append(qv % escape(a))
self._query(join(q, ',\n')) self._query(join(q, ',\n'))
def _query(self, q): def _query(self, q):
@ -89,30 +144,21 @@ class Cursor:
def fetchall(self): return self.result.fetch_all_rows() def fetchall(self): return self.result.fetch_all_rows()
def nextset(self): pass def nextset(self): return None
class Connection: class Connection:
CursorClass = Cursor CursorClass = Cursor
def __init__(self, dsn=None, user=None, password=None, def __init__(self, **kwargs):
host=None, database=None, **kwargs): self.db = apply(_mysql.connect, (), kwargs)
newargs = {}
if user: newargs['user'] = user
if password: newargs['passwd'] = password
if host: newargs['host'] = host
if database: newargs['db'] = database
newargs.update(kwargs)
self.db = apply(_mysql.connect, (), newargs)
def close(self): def close(self):
self.db.close() self.db.close()
def commit(self): pass def commit(self): pass
# def rollback(self): raise OperationalError, "transactions not supported"
def cursor(self, name=''): def cursor(self, name=''):
return self.CursorClass(self, name) return self.CursorClass(self, name)

View File

@ -508,7 +508,10 @@ _mysql_escape_row(self, args)
char *in, *out; char *in, *out;
int i, n, len, size; int i, n, len, size;
if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2; if (!PyArg_ParseTuple(args, "O:escape_row", &o)) goto error2;
if (!PySequence_Check(o)) goto error2; if (!PySequence_Check(o)) {
PyErr_SetString(PyExc_TypeError, "sequence required");
goto error2;
}
if (!(n = PyObject_Length(o))) goto error2; if (!(n = PyObject_Length(o))) goto error2;
if (!(r = PyTuple_New(n))) goto error; if (!(r = PyTuple_New(n))) goto error;
for (i=0; i<n; i++) { for (i=0; i<n; i++) {