From d4529e5d15b741179ad3ec29b658fee2ce171b31 Mon Sep 17 00:00:00 2001 From: adustman Date: Fri, 24 Mar 2000 05:46:00 +0000 Subject: [PATCH] Break up the various cursor variations into MixIn classes. Found a work-around for the way MySQL treats TIME literals with dates and/or fractional seconds. Added a mutex into the cursor so that connections can be shared between threads. threadsafety=2 I could easily make the cursors sharable as well (threadsafety=3) but I hardly see the point. Even sharing connections is not a good idea, because you don't get the benefit of multiple mysqld threads. --- mysql/MySQLdb.py | 274 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 195 insertions(+), 79 deletions(-) diff --git a/mysql/MySQLdb.py b/mysql/MySQLdb.py index 7b36e19..8e1bd94 100644 --- a/mysql/MySQLdb.py +++ b/mysql/MySQLdb.py @@ -21,10 +21,17 @@ from _mysql import * from time import localtime import re, types -threadsafety = 1 +threadsafety = 2 apilevel = "2.0" paramstyle = "format" +try: + import threading + _threading = threading + del threading +except ImportError: + _threading = None + def Long2Int(l): return str(l)[:-1] # drop the trailing L def None2NULL(d): return "NULL" def String2literal(s): return "'%s'" % escape_string(str(s)) @@ -45,7 +52,8 @@ type_conv = { FIELD_TYPE.TINY: int, FIELD_TYPE.YEAR: int } try: - from DateTime import Date, Time, Timestamp, ISO, DateTimeType + from DateTime import Date, Time, Timestamp, ISO, \ + DateTimeType, DateTimeDeltaType def DateFromTicks(ticks): return apply(Date, localtime(ticks)[:3]) @@ -56,9 +64,9 @@ try: def TimestampFromTicks(ticks): return apply(Timestamp, localtime(ticks)[:6]) - def format_DATE(d): return d.Format("%Y-%m-%d") - def format_TIME(d): return d.Format("%H:%M:%S") - def format_TIMESTAMP(d): return d.Format("%Y-%m-%d %H:%M:%S") + def format_DATE(d): return d.strftime("%Y-%m-%d") + def format_TIME(d): return d.strftime("%H:%M:%S") + def format_TIMESTAMP(d): return d.strftime("%Y-%m-%d %H:%M:%S") def mysql_timestamp_converter(s): parts = map(int, filter(None, (s[:4],s[4:6],s[6:8], @@ -67,12 +75,14 @@ try: type_conv[FIELD_TYPE.TIMESTAMP] = mysql_timestamp_converter type_conv[FIELD_TYPE.DATETIME] = ISO.ParseDateTime - #type_conv[FIELD_TYPE.TIME] = ISO.ParseTime + type_conv[FIELD_TYPE.TIME] = ISO.ParseTime type_conv[FIELD_TYPE.DATE] = ISO.ParseDate - #def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d) + def DateTime2literal(d): return "'%s'" % format_TIMESTAMP(d) + def DateTimeDelta2literal(d): return "'%s'" % format_TIME(d) - #quote_conv[DateTimeType] = DateTime2literal + quote_conv[DateTimeType] = DateTime2literal + quote_conv[DateTimeDeltaType] = DateTimeDelta2literal except ImportError: # no DateTime? We'll muddle through somehow. @@ -130,51 +140,56 @@ def escape_dict(d, qc): return d2 -class _Cursor: +class BaseCursor: - """Created by a Connection object. Useful attributes: + """A base for Cursor classes. Useful attributes: description -- DB API 7-tuple describing columns in last query arraysize -- default number of rows fetchmany() will fetch - warnings -- should MySQL warnings raise a Warning exception? - use -- should mysql_use_result be used instead of mysql_store_result? - By default, warnings are issued, and mysql_store_result is used. See the MySQL docs for more information.""" - def __init__(self, connection, name='', use=0, warnings=1): + def __init__(self, connection, warnings=1): self.connection = connection - self.name = name self.description = None self.rowcount = -1 self.result = None self.arraysize = 100 self.warnings = warnings - self.use = use + + def close(self): + self.connection = None + + def _check_open(self): + if not self.connection: + raise ProgrammingError, "cursor closed" def setinputsizes(self, *args): pass def setoutputsizes(self, *args): pass def execute(self, query, args=None): - """cursor.execute(query, args=None) + """rows=cursor.execute(query, args=None) query -- string, query to execute on server - args -- sequence or mapping, parameters to use with query.""" + args -- sequence or mapping, parameters to use with query. + rows -- rows affected, if any""" + self._check_open() from types import ListType, TupleType from string import rfind, join, split, atoi qc = self.connection.quote_conv if not args: - self._query(query) + return self._query(query) elif type(args) is ListType and type(args[0]) is TupleType: - self.executemany(query, args) # deprecated + return self.executemany(query, args) # deprecated else: try: - self._query(query % escape_row(args, qc)) + return self._query(query % escape_row(args, qc)) except TypeError, m: - if m.args[0] == "not enough arguments for format string": raise - if m.args[0] == "not all arguments converted": raise - self._query(query % escape_dict(args, qc)) + if m.args[0] in ("not enough arguments for format string", + "not all arguments converted"): + raise ProgrammingError, m.args[0] + return self._query(query % escape_dict(args, qc)) def executemany(self, query, args): """cursor.executemany(self, query, args) @@ -186,6 +201,7 @@ class _Cursor: item in the sequence. This method performs multiple-row inserts and similar queries.""" + self._check_open() from string import join m = insert_values.search(query) if not m: raise ProgrammingError, "can't find values" @@ -195,33 +211,117 @@ class _Cursor: try: q = [query % escape(args[0], qc)] except TypeError, m: - if m.args[0] == "not enough arguments for format string": raise - if m.args[0] == "not all arguments converted": raise + if m.args[0] in ("not enough arguments for format string", + "not all arguments converted"): + raise ProgrammingError, m.args[0] escape = escape_dict q = [query % escape(args[0], qc)] qv = query[p:] for a in args[1:]: q.append(qv % escape(a, qc)) - self._query(join(q, ',\n')) + return self._query(join(q, ',\n')) - def _query(self, q): + def _do_query(self, q): from string import split, atoi db = self.connection.db db.query(q) - if self.use: self.result = db.use_result() - else: self.result = db.store_result() + self.result = self._get_result() self.rowcount = db.affected_rows() self.description = self.result and self.result.describe() or None - if self.warnings: - w = db.info() - if w: - warnings = atoi(split(w)[-1]) - if warnings: - raise Warning, w + self.__insert_id = db.insert_id() + return self.rowcount + + _query = _do_query + + def insert_id(self): + try: return self.__insert_id + except AttributeError: raise ProgrammingError, "execute() first" + + def nextset(self): return None + + +class CursorWarningMixIn: + + def _query(self, q): + from string import atoi, split + r = self._do_query(q) + w = self.connection.db.info() + if w: + warnings = atoi(split(w)[-1]) + if warnings: + raise Warning, w + return r + + +class CursorStoreResultMixIn: + + def _get_result(self): return self.connection.db.store_result() + + def _do_query(self, q): + self.connection._acquire() + try: + BaseCursor._do_query(self, q) + self.__rows = self.result and self._fetch_all_rows() or ((),) + self.__pos = 0 + del self.result + finally: + self.connection._release() + + def fetchone(self): + """Fetches a single row from the cursor.""" + result = self.__rows[self.__pos] + self.__pos = self.__pos+1 + return result + + def fetchmany(self, size=None): + """cursor.fetchmany(size=cursor.arraysize) + + size -- integer, maximum number of rows to fetch.""" + end = self.__pos + size or self.arraysize + result = self.__rows[self.__pos:end] + self.__pos = end + return result + + def fetchall(self): + """Fetchs all available rows from the cursor.""" + result = self.__rows[self.__pos:] + self.__pos = len(self.__rows) + return result + + def seek(self, row, whence=0): + if whence == 0: + self.__pos = row + elif whence == 1: + self.__pos = self.__pos + row + elif whence == 2: + self.__pos = len(self.__rows) + row + + def tell(self): return self.__pos + + +class CursorUseResultMixIn: + + def __init__(self, name=""): + BaseCursor.__init__(self, name="") + if not self.connection._acquire(0): + raise ProgrammingError, "would deadlock" + + def close(self): + self.connection._release() + self.connection = None + + def __del__(self): + try: + del self.result + finally: + self.close() + + def _get_result(self): return self.connection.db.use_result() def fetchone(self): """Fetches a single row from the cursor.""" + self._check_open() try: - return self.result.fetch_row() + return self._fetch_row() except AttributeError: raise ProgrammingError, "no query executed yet" @@ -229,41 +329,50 @@ class _Cursor: """cursor.fetchmany(size=cursor.arraysize) size -- integer, maximum number of rows to fetch.""" - return self.result.fetch_rows(size or self.arraysize) + self._check_open() + return self._fetch_rows(size or self.arraysize) def fetchall(self): """Fetchs all available rows from the cursor.""" - return self.result.fetch_all_rows() - - def fetchoneDict(self): - """Fetches a single row from the cursor as a dictionary.""" - try: - return self.result.fetch_row_as_dict() - except AttributeError: - raise ProgrammingError, "no query executed yet" - - def fetchmanyDict(self, size=None): - """cursor.fetchmany(size=cursor.arraysize) - - size -- integer, maximum number of rows to fetch. - rows are returned as dictionaries.""" - return self.result.fetch_rows_as_dict(size or self.arraysize) - - def fetchallDict(self): - """Fetchs all available rows from the cursor as dictionaries.""" - return self.result.fetch_all_rows_as_dict() - - def nextset(self): return None + self._check_open() + return self._fetch_all_rows() + + +class CursorTupleRowsMixIn: + + def _fetch_row(self): return self.result.fetch_row() + def _fetch_rows(self, size): return self.result.fetch_rows(size) + def _fetch_all_rows(self): return self.result.fetch_all_rows() + + +class CursorDictRowsMixIn: + + def _fetch_row(self): return self.result.fetch_row_as_dict() + def _fetch_rows(self, size): return self.result.fetch_rows_as_dict(size) + def _fetch_all_rows(self): return self.result.fetch_all_rows_as_dict() + + ## XXX Deprecated + + def fetchoneDict(self, *args, **kwargs): + return apply(self.fetchone, args, kwargs) + + def fetchmanyDict(self, *args, **kwargs): + return apply(self.fetchmany, args, kwargs) + + def fetchallDict(self, *args, **kwargs): + return apply(self.fetchall, args, kwargs) + + +class Cursor(CursorWarningMixIn, CursorStoreResultMixIn, + CursorTupleRowsMixIn, BaseCursor): pass +class DictCursor(CursorWarningMixIn, CursorStoreResultMixIn, + CursorDictRowsMixIn, BaseCursor): pass +class SSCursor(CursorWarningMixIn, CursorUseResultMixIn, + CursorTupleRowsMixIn, BaseCursor): pass +class SSDictCursor(CursorWarningMixIn, CursorUseResultMixIn, + CursorDictRowsMixIn, BaseCursor): pass - def seek(self, row, whence=0): - if self.use: raise NotSupportedError, "use must be 0 to use seek" - if whence: raise NotSupportedError, "can't do relative seek" - return self.result.data_seek(row) - ## return whence and self.result.row_seek(row) or self.result.data_seek(row) - - ## def tell(self): return self.result.row_tell() - class Connection: """Connection(host=NULL, user=NULL, passwd=NULL, db=NULL, @@ -289,18 +398,27 @@ class Connection: MySQL-specific calls. close -- close the connection. cursor -- create a cursor (emulated) for executing queries. - CursorClass -- class used to create cursors (_Cursor). If you subclass - the Connection object, you will probably want to override this. """ - CursorClass = _Cursor - def __init__(self, **kwargs): from _mysql import connect if not kwargs.has_key('conv'): kwargs['conv'] = type_conv.copy() self.quote_conv = kwargs.get('quote_conv', quote_conv.copy()) + if kwargs.has_key('cursorclass'): + self.cursorclass = kwargs['cursorclass'] + del kwargs['cursorclass'] + else: + self.cursorclass = Cursor self.db = apply(connect, (), kwargs) - + if _threading: self.__lock = _threading.Lock() + + if _threading: + def _acquire(self, blocking=1): return self.__lock.acquire(blocking) + def _release(self): return self.__lock.release() + else: + def _acquire(self, blocking=1): return 1 + def _release(self): return 1 + def close(self): """Close the connection. No further activity possible.""" self.db.close() @@ -309,17 +427,16 @@ class Connection: def commit(self): """Commit the current transaction.""" return self.db.commit() - else: - def commit(self): """Does nothing as there are no transactions.""" - - if hasattr(_mysql, 'transactions'): + def rollback(self): """Rollback the current transaction.""" self.db.rollback() + else: + def commit(self): """Does nothing as there are no transactions.""" - def cursor(self, *args, **kwargs): + def cursor(self, cursorclass=None): """Create a cursor on which queries may be performed.""" - return apply(self.CursorClass, (self,)+args, kwargs) + return (cursorclass or self.cursorclass)(self) # Non-portable MySQL-specific stuff # Methods not included on purpose (use Cursors instead): @@ -331,7 +448,6 @@ class Connection: def get_proto_info(self): return self.db.get_proto_info() def get_server_info(self): return self.db.get_server_info() def info(self): return self.db.info() - def insert_id(self): return self.db.insert_id() def kill(self, p): return self.db.kill(p) def list_dbs(self): return self.db.list_dbs().fetch_all_rows() def list_fields(self, table): return self.db.list_fields(table).fetch_all_rows()