Warning propagation improvements

This commit is contained in:
Vilnis Termanis
2016-07-26 18:43:35 +01:00
parent 8c219d9261
commit ad935e4ba8
2 changed files with 44 additions and 14 deletions

View File

@ -73,8 +73,7 @@ class BaseCursor(object):
self.messages = [] self.messages = []
self.errorhandler = connection.errorhandler self.errorhandler = connection.errorhandler
self._result = None self._result = None
self._warnings = 0 self._warnings = None
self._info = None
self.rownumber = None self.rownumber = None
def close(self): def close(self):
@ -128,29 +127,37 @@ class BaseCursor(object):
def _warning_check(self): def _warning_check(self):
from warnings import warn from warnings import warn
db = self._get_db()
# None => warnings not interrogated for current query yet
# 0 => no warnings exists or have been handled already for this query
if self._warnings is None:
self._warnings = db.warning_count()
if self._warnings: if self._warnings:
# Only propagate warnings for current query once
warning_count = self._warnings
self._warnings = 0
# When there is next result, fetching warnings cause "command # When there is next result, fetching warnings cause "command
# out of sync" error. # out of sync" error.
if self._result and self._result.has_next: if self._result and self._result.has_next:
msg = "There are %d MySQL warnings." % (self._warnings,) msg = "There are %d MySQL warnings." % (warning_count,)
self.messages.append(msg) self.messages.append(msg)
warn(msg, self.Warning, 3) warn(self.Warning(0, msg), stacklevel=3)
return return
warnings = self._get_db().show_warnings() warnings = db.show_warnings()
if warnings: if warnings:
# This is done in two loops in case # This is done in two loops in case
# Warnings are set to raise exceptions. # Warnings are set to raise exceptions.
for w in warnings: for w in warnings:
self.messages.append((self.Warning, w)) self.messages.append((self.Warning, w))
for w in warnings: for w in warnings:
msg = w[-1] warn(self.Warning(*w[1:3]), stacklevel=3)
if not PY2 and isinstance(msg, bytes): else:
msg = msg.decode() info = db.info()
warn(msg, self.Warning, 3) if info:
elif self._info: self.messages.append((self.Warning, info))
self.messages.append((self.Warning, self._info)) warn(self.Warning(0, info), stacklevel=3)
warn(self._info, self.Warning, 3)
def nextset(self): def nextset(self):
"""Advance to the next result set. """Advance to the next result set.
@ -180,8 +187,7 @@ class BaseCursor(object):
self.description = self._result and self._result.describe() or None self.description = self._result and self._result.describe() or None
self.description_flags = self._result and self._result.field_flags() or None self.description_flags = self._result and self._result.field_flags() or None
self.lastrowid = db.insert_id() self.lastrowid = db.insert_id()
self._warnings = db.warning_count() self._warnings = None
self._info = db.info()
def setinputsizes(self, *args): def setinputsizes(self, *args):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""

View File

@ -3,6 +3,8 @@ import capabilities
from datetime import timedelta from datetime import timedelta
import unittest import unittest
import MySQLdb import MySQLdb
from MySQLdb.compat import unicode
from MySQLdb import cursors
import warnings import warnings
@ -155,6 +157,28 @@ class test_MySQLdb(capabilities.DatabaseTest):
return return
self.fail("Should raise ProgrammingError") self.fail("Should raise ProgrammingError")
def test_warning_propagation(self):
with warnings.catch_warnings():
# Ignore all warnings other than MySQLdb generated ones
warnings.simplefilter("ignore")
warnings.simplefilter("error", category=MySQLdb.Warning)
# verify for both buffered and unbuffered cursor types
for cursor_class in (cursors.Cursor, cursors.SSCursor):
c = self.connection.cursor(cursor_class)
try:
c.execute("SELECT CAST('124b' AS SIGNED)")
c.fetchall()
except MySQLdb.Warning as e:
# Warnings should have errorcode and string message, just like exceptions
self.assertEqual(len(e.args), 2)
self.assertEqual(e.args[0], 1292)
self.assertTrue(isinstance(e.args[1], unicode))
else:
self.fail("Should raise Warning")
finally:
c.close()
if __name__ == '__main__': if __name__ == '__main__':
if test_MySQLdb.leak_test: if test_MySQLdb.leak_test: