Add missing checks for connection before calling mysql APIs (#272)

Fixes #270
This commit is contained in:
INADA Naoki
2018-10-22 21:26:57 +09:00
committed by GitHub
parent 54c69436f4
commit 4a4978d6b1

View File

@ -69,9 +69,14 @@ typedef struct {
PyObject *converter; PyObject *converter;
} _mysql_ConnectionObject; } _mysql_ConnectionObject;
#define check_connection(c) if (!(c->open)) return _mysql_Exception(c) #define check_connection(c, func) \
if (!(c->open)) { \
PyErr_SetString(_mysql_ProgrammingError, func "() is called for closed connection"); \
return NULL; \
};
#define result_connection(r) ((_mysql_ConnectionObject *)r->conn) #define result_connection(r) ((_mysql_ConnectionObject *)r->conn)
#define check_result_connection(r) check_connection(result_connection(r)) #define check_result_connection(r, func) check_connection(result_connection(r), func)
extern PyTypeObject _mysql_ConnectionObject_Type; extern PyTypeObject _mysql_ConnectionObject_Type;
@ -750,6 +755,7 @@ static PyObject *
_mysql_ConnectionObject_fileno( _mysql_ConnectionObject_fileno(
_mysql_ConnectionObject *self) _mysql_ConnectionObject *self)
{ {
check_connection(self, "fileno");
return PyInt_FromLong(self->connection.net.fd); return PyInt_FromLong(self->connection.net.fd);
} }
@ -761,16 +767,11 @@ _mysql_ConnectionObject_close(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
if (self->open) { check_connection(self, "close");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
mysql_close(&(self->connection)); mysql_close(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
self->open = 0; self->open = 0;
} else {
PyErr_SetString(_mysql_ProgrammingError,
"closing a closed connection");
return NULL;
}
_mysql_ConnectionObject_clear(self); _mysql_ConnectionObject_clear(self);
Py_RETURN_NONE; Py_RETURN_NONE;
} }
@ -786,7 +787,7 @@ _mysql_ConnectionObject_affected_rows(
PyObject *noargs) PyObject *noargs)
{ {
my_ulonglong ret; my_ulonglong ret;
check_connection(self); check_connection(self, "affected_rows");
ret = mysql_affected_rows(&(self->connection)); ret = mysql_affected_rows(&(self->connection));
if (ret == (my_ulonglong)-1) if (ret == (my_ulonglong)-1)
return PyInt_FromLong(-1); return PyInt_FromLong(-1);
@ -823,7 +824,7 @@ _mysql_ConnectionObject_dump_debug_info(
PyObject *noargs) PyObject *noargs)
{ {
int err; int err;
check_connection(self); check_connection(self, "dump_debug_info");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_dump_debug_info(&(self->connection)); err = mysql_dump_debug_info(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -842,6 +843,7 @@ _mysql_ConnectionObject_autocommit(
{ {
int flag, err; int flag, err;
if (!PyArg_ParseTuple(args, "i", &flag)) return NULL; if (!PyArg_ParseTuple(args, "i", &flag)) return NULL;
check_connection(self, "autocommit");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_autocommit(&(self->connection), flag); err = mysql_autocommit(&(self->connection), flag);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -858,6 +860,7 @@ _mysql_ConnectionObject_get_autocommit(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *args) PyObject *args)
{ {
check_connection(self, "get_autocommit");
if (self->connection.server_status & SERVER_STATUS_AUTOCOMMIT) { if (self->connection.server_status & SERVER_STATUS_AUTOCOMMIT) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
@ -873,6 +876,7 @@ _mysql_ConnectionObject_commit(
PyObject *noargs) PyObject *noargs)
{ {
int err; int err;
check_connection(self, "commit");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_commit(&(self->connection)); err = mysql_commit(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -890,12 +894,12 @@ _mysql_ConnectionObject_rollback(
PyObject *noargs) PyObject *noargs)
{ {
int err; int err;
check_connection(self, "rollback");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_rollback(&(self->connection)); err = mysql_rollback(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
if (err) return _mysql_Exception(self); if (err) return _mysql_Exception(self);
Py_INCREF(Py_None); Py_RETURN_NONE;
return Py_None;
} }
static char _mysql_ConnectionObject_next_result__doc__[] = static char _mysql_ConnectionObject_next_result__doc__[] =
@ -917,6 +921,7 @@ _mysql_ConnectionObject_next_result(
PyObject *noargs) PyObject *noargs)
{ {
int err; int err;
check_connection(self, "next_result");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_next_result(&(self->connection)); err = mysql_next_result(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -939,6 +944,7 @@ _mysql_ConnectionObject_set_server_option(
int err, flags=0; int err, flags=0;
if (!PyArg_ParseTuple(args, "i", &flags)) if (!PyArg_ParseTuple(args, "i", &flags))
return NULL; return NULL;
check_connection(self, "set_server_option");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_set_server_option(&(self->connection), flags); err = mysql_set_server_option(&(self->connection), flags);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -963,6 +969,7 @@ _mysql_ConnectionObject_sqlstate(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self, "sqlstate");
return PyString_FromString(mysql_sqlstate(&(self->connection))); return PyString_FromString(mysql_sqlstate(&(self->connection)));
} }
@ -977,6 +984,7 @@ _mysql_ConnectionObject_warning_count(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self, "warning_count");
return PyInt_FromLong(mysql_warning_count(&(self->connection))); return PyInt_FromLong(mysql_warning_count(&(self->connection)));
} }
@ -991,7 +999,7 @@ _mysql_ConnectionObject_errno(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "errno");
return PyInt_FromLong((long)mysql_errno(&(self->connection))); return PyInt_FromLong((long)mysql_errno(&(self->connection)));
} }
@ -1006,7 +1014,7 @@ _mysql_ConnectionObject_error(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "error");
return PyString_FromString(mysql_error(&(self->connection))); return PyString_FromString(mysql_error(&(self->connection)));
} }
@ -1249,7 +1257,7 @@ _mysql_ResultObject_describe(
PyObject *d; PyObject *d;
MYSQL_FIELD *fields; MYSQL_FIELD *fields;
unsigned int i, n; unsigned int i, n;
check_result_connection(self); check_result_connection(self, "describe");
n = mysql_num_fields(self->result); n = mysql_num_fields(self->result);
fields = mysql_fetch_fields(self->result); fields = mysql_fetch_fields(self->result);
if (!(d = PyTuple_New(n))) return NULL; if (!(d = PyTuple_New(n))) return NULL;
@ -1284,7 +1292,7 @@ _mysql_ResultObject_field_flags(
PyObject *d; PyObject *d;
MYSQL_FIELD *fields; MYSQL_FIELD *fields;
unsigned int i, n; unsigned int i, n;
check_result_connection(self); check_result_connection(self, "field_flags");
n = mysql_num_fields(self->result); n = mysql_num_fields(self->result);
fields = mysql_fetch_fields(self->result); fields = mysql_fetch_fields(self->result);
if (!(d = PyTuple_New(n))) return NULL; if (!(d = PyTuple_New(n))) return NULL;
@ -1523,7 +1531,7 @@ _mysql_ResultObject_fetch_row(
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist,
&maxrows, &how)) &maxrows, &how))
return NULL; return NULL;
check_result_connection(self); check_result_connection(self, "fetch_row");
if (how >= (int)sizeof(row_converters)) { if (how >= (int)sizeof(row_converters)) {
PyErr_SetString(PyExc_ValueError, "how out of range"); PyErr_SetString(PyExc_ValueError, "how out of range");
return NULL; return NULL;
@ -1592,7 +1600,7 @@ _mysql_ConnectionObject_change_user(
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user", if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user",
kwlist, &user, &pwd, &db)) kwlist, &user, &pwd, &db))
return NULL; return NULL;
check_connection(self); check_connection(self, "change_user");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_change_user(&(self->connection), user, pwd, db); r = mysql_change_user(&(self->connection), user, pwd, db);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -1612,7 +1620,7 @@ _mysql_ConnectionObject_character_set_name(
PyObject *noargs) PyObject *noargs)
{ {
const char *s; const char *s;
check_connection(self); check_connection(self, "character_set_name");
s = mysql_character_set_name(&(self->connection)); s = mysql_character_set_name(&(self->connection));
return PyString_FromString(s); return PyString_FromString(s);
} }
@ -1630,7 +1638,7 @@ _mysql_ConnectionObject_set_character_set(
const char *s; const char *s;
int err; int err;
if (!PyArg_ParseTuple(args, "s", &s)) return NULL; if (!PyArg_ParseTuple(args, "s", &s)) return NULL;
check_connection(self); check_connection(self, "set_character_set");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
err = mysql_set_character_set(&(self->connection), s); err = mysql_set_character_set(&(self->connection), s);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -1669,7 +1677,7 @@ _mysql_ConnectionObject_get_character_set_info(
PyObject *result; PyObject *result;
MY_CHARSET_INFO cs; MY_CHARSET_INFO cs;
check_connection(self); check_connection(self, "get_character_set_info");
mysql_get_character_set_info(&(self->connection), &cs); mysql_get_character_set_info(&(self->connection), &cs);
if (!(result = PyDict_New())) return NULL; if (!(result = PyDict_New())) return NULL;
if (cs.csname) if (cs.csname)
@ -1701,7 +1709,7 @@ _mysql_ConnectionObject_get_native_connection(
PyObject *noargs) PyObject *noargs)
{ {
PyObject *result; PyObject *result;
check_connection(self); check_connection(self, "_get_native_connection");
result = PyCapsule_New(&(self->connection), result = PyCapsule_New(&(self->connection),
"_mysql.connection.native_connection", NULL); "_mysql.connection.native_connection", NULL);
return result; return result;
@ -1730,7 +1738,7 @@ _mysql_ConnectionObject_get_host_info(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "get_host_info");
return PyString_FromString(mysql_get_host_info(&(self->connection))); return PyString_FromString(mysql_get_host_info(&(self->connection)));
} }
@ -1744,7 +1752,7 @@ _mysql_ConnectionObject_get_proto_info(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "get_proto_info");
return PyInt_FromLong((long)mysql_get_proto_info(&(self->connection))); return PyInt_FromLong((long)mysql_get_proto_info(&(self->connection)));
} }
@ -1758,7 +1766,7 @@ _mysql_ConnectionObject_get_server_info(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "get_server_info");
return PyString_FromString(mysql_get_server_info(&(self->connection))); return PyString_FromString(mysql_get_server_info(&(self->connection)));
} }
@ -1774,7 +1782,7 @@ _mysql_ConnectionObject_info(
PyObject *noargs) PyObject *noargs)
{ {
const char *s; const char *s;
check_connection(self); check_connection(self, "info");
s = mysql_info(&(self->connection)); s = mysql_info(&(self->connection));
if (s) return PyString_FromString(s); if (s) return PyString_FromString(s);
Py_INCREF(Py_None); Py_INCREF(Py_None);
@ -1808,7 +1816,7 @@ _mysql_ConnectionObject_insert_id(
PyObject *noargs) PyObject *noargs)
{ {
my_ulonglong r; my_ulonglong r;
check_connection(self); check_connection(self, "insert_id");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_insert_id(&(self->connection)); r = mysql_insert_id(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -1827,7 +1835,7 @@ _mysql_ConnectionObject_kill(
unsigned long pid; unsigned long pid;
int r; int r;
if (!PyArg_ParseTuple(args, "k:kill", &pid)) return NULL; if (!PyArg_ParseTuple(args, "k:kill", &pid)) return NULL;
check_connection(self); check_connection(self, "kill");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_kill(&(self->connection), pid); r = mysql_kill(&(self->connection), pid);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -1847,7 +1855,7 @@ _mysql_ConnectionObject_field_count(
_mysql_ConnectionObject *self, _mysql_ConnectionObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_connection(self); check_connection(self, "field_count");
return PyInt_FromLong((long)mysql_field_count(&(self->connection))); return PyInt_FromLong((long)mysql_field_count(&(self->connection)));
} }
@ -1859,7 +1867,7 @@ _mysql_ResultObject_num_fields(
_mysql_ResultObject *self, _mysql_ResultObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_result_connection(self); check_result_connection(self, "num_fields");
return PyInt_FromLong((long)mysql_num_fields(self->result)); return PyInt_FromLong((long)mysql_num_fields(self->result));
} }
@ -1874,7 +1882,7 @@ _mysql_ResultObject_num_rows(
_mysql_ResultObject *self, _mysql_ResultObject *self,
PyObject *noargs) PyObject *noargs)
{ {
check_result_connection(self); check_result_connection(self, "num_rows");
return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result)); return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result));
} }
@ -1904,7 +1912,7 @@ _mysql_ConnectionObject_ping(
{ {
int r, reconnect = -1; int r, reconnect = -1;
if (!PyArg_ParseTuple(args, "|I", &reconnect)) return NULL; if (!PyArg_ParseTuple(args, "|I", &reconnect)) return NULL;
check_connection(self); check_connection(self, "ping");
if (reconnect != -1) { if (reconnect != -1) {
my_bool recon = (my_bool)reconnect; my_bool recon = (my_bool)reconnect;
mysql_options(&self->connection, MYSQL_OPT_RECONNECT, &recon); mysql_options(&self->connection, MYSQL_OPT_RECONNECT, &recon);
@ -1931,7 +1939,7 @@ _mysql_ConnectionObject_query(
char *query; char *query;
int len, r; int len, r;
if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL;
check_connection(self); check_connection(self, "query");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_real_query(&(self->connection), query, len); r = mysql_real_query(&(self->connection), query, len);
@ -1955,7 +1963,7 @@ _mysql_ConnectionObject_send_query(
int len, r; int len, r;
MYSQL *mysql = &(self->connection); MYSQL *mysql = &(self->connection);
if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL;
check_connection(self); check_connection(self, "send_query");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_send_query(mysql, query, len); r = mysql_send_query(mysql, query, len);
@ -1976,7 +1984,7 @@ _mysql_ConnectionObject_read_query_result(
{ {
int r; int r;
MYSQL *mysql = &(self->connection); MYSQL *mysql = &(self->connection);
check_connection(self); check_connection(self, "reqd_query_result");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = (int)mysql_read_query_result(mysql); r = (int)mysql_read_query_result(mysql);
@ -2006,7 +2014,7 @@ _mysql_ConnectionObject_select_db(
char *db; char *db;
int r; int r;
if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL; if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL;
check_connection(self); check_connection(self, "select_db");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_select_db(&(self->connection), db); r = mysql_select_db(&(self->connection), db);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -2026,7 +2034,7 @@ _mysql_ConnectionObject_shutdown(
PyObject *noargs) PyObject *noargs)
{ {
int r; int r;
check_connection(self); check_connection(self, "shutdown");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
r = mysql_shutdown(&(self->connection), SHUTDOWN_DEFAULT); r = mysql_shutdown(&(self->connection), SHUTDOWN_DEFAULT);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -2048,7 +2056,7 @@ _mysql_ConnectionObject_stat(
PyObject *noargs) PyObject *noargs)
{ {
const char *s; const char *s;
check_connection(self); check_connection(self, "stat");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
s = mysql_stat(&(self->connection)); s = mysql_stat(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -2070,7 +2078,7 @@ _mysql_ConnectionObject_store_result(
PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL;
_mysql_ResultObject *r=NULL; _mysql_ResultObject *r=NULL;
check_connection(self); check_connection(self, "store_result");
arglist = Py_BuildValue("(OiO)", self, 0, self->converter); arglist = Py_BuildValue("(OiO)", self, 0, self->converter);
if (!arglist) goto error; if (!arglist) goto error;
kwarglist = PyDict_New(); kwarglist = PyDict_New();
@ -2108,7 +2116,7 @@ _mysql_ConnectionObject_thread_id(
PyObject *noargs) PyObject *noargs)
{ {
unsigned long pid; unsigned long pid;
check_connection(self); check_connection(self, "thread_id");
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
pid = mysql_thread_id(&(self->connection)); pid = mysql_thread_id(&(self->connection));
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
@ -2129,7 +2137,7 @@ _mysql_ConnectionObject_use_result(
PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL;
_mysql_ResultObject *r=NULL; _mysql_ResultObject *r=NULL;
check_connection(self); check_connection(self, "use_result");
arglist = Py_BuildValue("(OiO)", self, 1, self->converter); arglist = Py_BuildValue("(OiO)", self, 1, self->converter);
if (!arglist) return NULL; if (!arglist) return NULL;
kwarglist = PyDict_New(); kwarglist = PyDict_New();
@ -2187,7 +2195,7 @@ _mysql_ResultObject_data_seek(
{ {
unsigned int row; unsigned int row;
if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL; if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL;
check_result_connection(self); check_result_connection(self, "data_seek");
mysql_data_seek(self->result, row); mysql_data_seek(self->result, row);
Py_INCREF(Py_None); Py_INCREF(Py_None);
return Py_None; return Py_None;