From 233217ce24dfd862a84fe8a032da6a481b71a794 Mon Sep 17 00:00:00 2001 From: adustman Date: Thu, 19 Apr 2001 17:32:53 +0000 Subject: [PATCH] Be more paranoid about operating on closed connections. --- mysql/_mysqlmodule.c | 94 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/mysql/_mysqlmodule.c b/mysql/_mysqlmodule.c index 622599f..fa7e524 100644 --- a/mysql/_mysqlmodule.c +++ b/mysql/_mysqlmodule.c @@ -60,12 +60,15 @@ typedef struct { PyObject *converter; } _mysql_ConnectionObject; +#define check_connection(c) if (!(c->open)) _mysql_Exception(c) +#define result_connection(r) ((_mysql_ConnectionObject *)r->conn) +#define check_result_connection(r) if (!(result_connection(r)->open)) _mysql_Exception(result_connection(r)) + extern PyTypeObject _mysql_ConnectionObject_Type; typedef struct { PyObject_HEAD PyObject *conn; - MYSQL *connection; MYSQL_RES *result; int nfields; int use; @@ -81,6 +84,14 @@ _mysql_Exception(_mysql_ConnectionObject *c) int merr; if (!(t = PyTuple_New(2))) return NULL; + if (!(c->open)) { + e = _mysql_ProgrammingError; + PyTuple_SET_ITEM(t, 0, PyInt_FromLong(-1L)); + PyTuple_SET_ITEM(t, 1, PyString_FromString("connection is closed")); + PyErr_SetObject(e, t); + Py_DECREF(t); + return NULL; + } merr = mysql_errno(&(c->connection)); if (!merr) e = _mysql_InterfaceError; @@ -125,7 +136,6 @@ _mysql_ResultObject_New( _mysql_ResultObject *r; if (!(r = PyObject_NEW(_mysql_ResultObject, &_mysql_ResultObject_Type))) return NULL; - r->connection = &conn->connection; r->conn = (PyObject *) conn; r->converter = NULL; r->use = use; @@ -205,7 +215,6 @@ _mysql_connect( _mysql_ConnectionObject *c = PyObject_NEW(_mysql_ConnectionObject, &_mysql_ConnectionObject_Type); if (c == NULL) return NULL; - c->open = 0; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisss:connect", kwlist, &host, &user, &passwd, &db, @@ -261,10 +270,12 @@ _mysql_ConnectionObject_close( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; - Py_BEGIN_ALLOW_THREADS - mysql_close(&(self->connection)); - Py_END_ALLOW_THREADS - self->open = 0; + if (self->open) { + Py_BEGIN_ALLOW_THREADS + mysql_close(&(self->connection)); + Py_END_ALLOW_THREADS + self->open = 0; + } Py_INCREF(Py_None); return Py_None; } @@ -275,6 +286,7 @@ _mysql_ConnectionObject_affected_rows( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyLong_FromUnsignedLongLong(mysql_affected_rows(&(self->connection))); } @@ -303,6 +315,7 @@ _mysql_ConnectionObject_dump_debug_info( { int err; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS err = mysql_dump_debug_info(&(self->connection)); Py_END_ALLOW_THREADS @@ -317,6 +330,7 @@ _mysql_ConnectionObject_errno( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyInt_FromLong((long)mysql_errno(&(self->connection))); } @@ -326,6 +340,7 @@ _mysql_ConnectionObject_error( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyString_FromString(mysql_error(&(self->connection))); } @@ -353,8 +368,10 @@ _mysql_escape_string( #if MYSQL_VERSION_ID < 32321 len = mysql_escape_string(out, in, size); #else - if (self) + if (self) { + check_connection(self); len = mysql_real_escape_string(&(self->connection), out, in, size); + } else len = mysql_escape_string(out, in, size); #endif @@ -380,10 +397,10 @@ _mysql_string_literal( _mysql_ConnectionObject *self, PyObject *args) { - PyObject *str, *s, *o; + PyObject *str, *s, *o, *d; char *in, *out; int len, size; - if (!PyArg_ParseTuple(args, "O:string_literal", &o)) return NULL; + if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL; s = PyObject_Str(o); in = PyString_AsString(s); size = PyString_GET_SIZE(s); @@ -393,8 +410,10 @@ _mysql_string_literal( #if MYSQL_VERSION_ID < 32321 len = mysql_escape_string(out+1, in, size); #else - if (self) + if (self) { + check_connection(self); len = mysql_real_escape_string(&(self->connection), out+1, in, size); + } else len = mysql_escape_string(out+1, in, size); #endif @@ -526,6 +545,7 @@ _mysql_ResultObject_describe( MYSQL_FIELD *fields; unsigned int i, n; if (!PyArg_NoArgs(args)) return NULL; + check_result_connection(self); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -557,6 +577,7 @@ _mysql_ResultObject_field_flags( MYSQL_FIELD *fields; unsigned int i, n; if (!PyArg_NoArgs(args)) return NULL; + check_result_connection(self); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -721,7 +742,7 @@ _mysql__fetch_row( row = mysql_fetch_row(self->result); Py_END_ALLOW_THREADS; } - if (!row && mysql_errno(self->connection)) { + if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) { _mysql_Exception((_mysql_ConnectionObject *)self->conn); goto error; } @@ -759,6 +780,7 @@ _mysql_ResultObject_fetch_row( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, &maxrows, &how)) return NULL; + check_result_connection(self); if (how < 0 || how >= sizeof(row_converters)) { PyErr_SetString(PyExc_ValueError, "how out of range"); return NULL; @@ -809,6 +831,7 @@ _mysql_ConnectionObject_change_user( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user", kwlist, &user, &pwd, &db)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_change_user(&(self->connection), user, pwd, db); Py_END_ALLOW_THREADS @@ -826,6 +849,7 @@ _mysql_ConnectionObject_character_set_name( { const char *s; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); s = mysql_character_set_name(&(self->connection)); return PyString_FromString(s); } @@ -849,6 +873,7 @@ _mysql_ConnectionObject_get_host_info( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyString_FromString(mysql_get_host_info(&(self->connection))); } @@ -858,6 +883,7 @@ _mysql_ConnectionObject_get_proto_info( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyInt_FromLong((long)mysql_get_proto_info(&(self->connection))); } @@ -867,6 +893,7 @@ _mysql_ConnectionObject_get_server_info( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); return PyString_FromString(mysql_get_server_info(&(self->connection))); } @@ -877,6 +904,7 @@ _mysql_ConnectionObject_info( { char *s; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); s = mysql_info(&(self->connection)); if (s) return PyString_FromString(s); Py_INCREF(Py_None); @@ -890,6 +918,7 @@ _mysql_ConnectionObject_insert_id( { my_ulonglong r; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_insert_id(&(self->connection)); Py_END_ALLOW_THREADS @@ -904,6 +933,7 @@ _mysql_ConnectionObject_kill( unsigned long pid; int r; if (!PyArg_ParseTuple(args, "i:kill", &pid)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_kill(&(self->connection), pid); Py_END_ALLOW_THREADS @@ -921,6 +951,7 @@ _mysql_ConnectionObject_list_dbs( char *wild = NULL; if (!PyArg_ParseTuple(args, "|s:list_dbs", &wild)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_list_dbs(&(self->connection), wild); Py_END_ALLOW_THREADS @@ -938,6 +969,7 @@ _mysql_ConnectionObject_list_fields( char *wild = NULL, *table; if (!PyArg_ParseTuple(args, "s|s:list_fields", &table, &wild)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_list_fields(&(self->connection), table, wild); Py_END_ALLOW_THREADS @@ -954,6 +986,7 @@ _mysql_ConnectionObject_list_processes( MYSQL_RES *result; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_list_processes(&(self->connection)); Py_END_ALLOW_THREADS @@ -971,6 +1004,7 @@ _mysql_ConnectionObject_list_tables( char *wild = NULL; if (!PyArg_ParseTuple(args, "|s:list_tables", &wild)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_list_tables(&(self->connection), wild); Py_END_ALLOW_THREADS @@ -985,6 +1019,7 @@ _mysql_ConnectionObject_field_count( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); #if MYSQL_VERSION_ID < 32224 return PyInt_FromLong((long)mysql_num_fields(&(self->connection))); #else @@ -998,6 +1033,7 @@ _mysql_ResultObject_num_fields( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_result_connection(self); return PyInt_FromLong((long)mysql_num_fields(self->result)); } @@ -1007,6 +1043,7 @@ _mysql_ResultObject_num_rows( PyObject *args) { if (!PyArg_NoArgs(args)) return NULL; + check_result_connection(self); return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result)); } @@ -1017,6 +1054,7 @@ _mysql_ConnectionObject_ping( { int r; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_ping(&(self->connection)); Py_END_ALLOW_THREADS @@ -1033,6 +1071,7 @@ _mysql_ConnectionObject_query( char *query; int len, r; if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_real_query(&(self->connection), query, len); Py_END_ALLOW_THREADS @@ -1049,6 +1088,7 @@ _mysql_ConnectionObject_select_db( char *db; int r; if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_select_db(&(self->connection), db); Py_END_ALLOW_THREADS @@ -1064,6 +1104,7 @@ _mysql_ConnectionObject_shutdown( { int r; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS r = mysql_shutdown(&(self->connection)); Py_END_ALLOW_THREADS @@ -1079,6 +1120,7 @@ _mysql_ConnectionObject_stat( { char *s; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS s = mysql_stat(&(self->connection)); Py_END_ALLOW_THREADS @@ -1094,6 +1136,7 @@ _mysql_ConnectionObject_store_result( MYSQL_RES *result; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_store_result(&(self->connection)); Py_END_ALLOW_THREADS @@ -1112,6 +1155,7 @@ _mysql_ConnectionObject_thread_id( { unsigned long pid; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS pid = mysql_thread_id(&(self->connection)); Py_END_ALLOW_THREADS @@ -1126,6 +1170,7 @@ _mysql_ConnectionObject_use_result( MYSQL_RES *result; if (!PyArg_NoArgs(args)) return NULL; + check_connection(self); Py_BEGIN_ALLOW_THREADS result = mysql_use_result(&(self->connection)); Py_END_ALLOW_THREADS @@ -1141,10 +1186,11 @@ static void _mysql_ConnectionObject_dealloc( _mysql_ConnectionObject *self) { + PyObject *o; + if (self->open) { - Py_BEGIN_ALLOW_THREADS - mysql_close(&(self->connection)); - Py_END_ALLOW_THREADS + o = _mysql_ConnectionObject_close(self, NULL); + Py_XDECREF(o); } Py_XDECREF(self->converter); PyMem_Free((char *) self); @@ -1155,10 +1201,13 @@ _mysql_ConnectionObject_repr( _mysql_ConnectionObject *self) { char buf[300]; - sprintf(buf, "<%s connection to '%.256s' at %lx>", - self->open ? "open" : "closed", - self->connection.host, - (long)self); + if (self->open) + sprintf(buf, "", + self->connection.host, + (long)self); + else + sprintf(buf, "", + (long)self); return PyString_FromString(buf); } @@ -1171,6 +1220,7 @@ _mysql_ResultObject_data_seek( { unsigned int row; if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL; + check_result_connection(self); mysql_data_seek(self->result, row); Py_INCREF(Py_None); return Py_None; @@ -1186,6 +1236,7 @@ _mysql_ResultObject_row_seek( int offset; MYSQL_ROW_OFFSET r; if (!PyArg_ParseTuple(args, "i:row_seek", &offset)) return NULL; + check_result_connection(self); r = mysql_row_tell(self->result); mysql_row_seek(self->result, r+offset); Py_INCREF(Py_None); @@ -1201,6 +1252,7 @@ _mysql_ResultObject_row_tell( { MYSQL_ROW_OFFSET r; if (!PyArg_NoArgs(args)) return NULL; + check_result_connection(self); r = mysql_row_tell(self->result); return PyInt_FromLong(r-self->result->data->data); } @@ -1303,8 +1355,10 @@ _mysql_ConnectionObject_getattr( return PyInt_FromLong((long)(self->open)); if (strcmp(name, "closed") == 0) return PyInt_FromLong((long)!(self->open)); - if (strcmp(name, "server_capabilities") == 0) + if (strcmp(name, "server_capabilities") == 0) { + check_connection(self); return PyInt_FromLong((long)(self->connection.server_capabilities)); + } return PyMember_Get((char *)self, _mysql_ConnectionObject_memberlist, name); }