diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a429244..a34925fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changes +## v1.1.0 + +Release date: TBD + +* Exposed `Cursor.warning_count` to check for warnings without additional query (#1056) + ## v1.0.3 Release date: TBD @@ -7,6 +13,7 @@ Release date: TBD * Dropped support of end of life MySQL version 5.6 * Dropped support of end of life MariaDB versions below 10.3 * Dropped support of end of life Python version 3.6 +* Exposed `Cursor.warning_count` to check for warnings without additional query (#1056) ## v1.0.2 diff --git a/pymysql/cursors.py b/pymysql/cursors.py index 2b5ccca9..e6206771 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -32,6 +32,7 @@ class Cursor: def __init__(self, connection): self.connection = connection + self.warning_count = 0 self.description = None self.rownumber = 0 self.rowcount = -1 @@ -331,6 +332,7 @@ def _clear_result(self): self._result = None self.rowcount = 0 + self.warning_count = 0 self.description = None self.lastrowid = None self._rows = None @@ -341,6 +343,7 @@ def _do_get_result(self): self._result = result = conn._result self.rowcount = result.affected_rows + self.warning_count = result.warning_count self.description = result.description self.lastrowid = result.insert_id self._rows = result.rows @@ -442,6 +445,7 @@ def fetchone(self): self._check_executed() row = self.read_next() if row is None: + self.warning_count = self._result.warning_count return None self.rownumber += 1 return row @@ -475,6 +479,7 @@ def fetchmany(self, size=None): for i in range(size): row = self.read_next() if row is None: + self.warning_count = self._result.warning_count break rows.append(row) self.rownumber += 1 diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py index a68a7769..d19d3e5d 100644 --- a/pymysql/tests/test_SSCursor.py +++ b/pymysql/tests/test_SSCursor.py @@ -3,13 +3,13 @@ try: from pymysql.tests import base import pymysql.cursors - from pymysql.constants import CLIENT + from pymysql.constants import CLIENT, ER except Exception: # For local testing from top-level directory, without installing sys.path.append("../pymysql") from pymysql.tests import base import pymysql.cursors - from pymysql.constants import CLIENT + from pymysql.constants import CLIENT, ER class TestSSCursor(base.PyMySQLTestCase): @@ -122,6 +122,35 @@ def test_SSCursor(self): cursor.execute("DROP TABLE IF EXISTS tz_data") cursor.close() + def test_warnings(self): + con = self.connect() + cur = con.cursor(pymysql.cursors.SSCursor) + cur.execute("DROP TABLE IF EXISTS `no_exists_table`") + self.assertEqual(cur.warning_count, 1) + + cur.execute("SHOW WARNINGS") + w = cur.fetchone() + self.assertEqual(w[1], ER.BAD_TABLE_ERROR) + self.assertIn( + "no_exists_table", + w[2], + ) + + # ensure unbuffered result is finished + self.assertIsNone(cur.fetchone()) + + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + self.assertIsNone(cur.fetchone()) + + self.assertEqual(cur.warning_count, 0) + + cur.execute("SELECT CAST('abc' AS SIGNED)") + # this ensures fully retrieving the unbuffered result + rows = cur.fetchmany(2) + self.assertEqual(len(rows), 1) + self.assertEqual(cur.warning_count, 1) + __all__ = ["TestSSCursor"] diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py index 783caf88..63ecce02 100644 --- a/pymysql/tests/test_cursor.py +++ b/pymysql/tests/test_cursor.py @@ -1,5 +1,4 @@ -import warnings - +from pymysql.constants import ER from pymysql.tests import base import pymysql.cursors @@ -129,3 +128,20 @@ def test_executemany(self): ) finally: cursor.execute("DROP TABLE IF EXISTS percent_test") + + def test_warnings(self): + con = self.connect() + cur = con.cursor() + cur.execute("DROP TABLE IF EXISTS `no_exists_table`") + self.assertEqual(cur.warning_count, 1) + + cur.execute("SHOW WARNINGS") + w = cur.fetchone() + self.assertEqual(w[1], ER.BAD_TABLE_ERROR) + self.assertIn( + "no_exists_table", + w[2], + ) + + cur.execute("SELECT 1") + self.assertEqual(cur.warning_count, 0) diff --git a/pymysql/tests/test_load_local.py b/pymysql/tests/test_load_local.py index b1b8128e..194c5be9 100644 --- a/pymysql/tests/test_load_local.py +++ b/pymysql/tests/test_load_local.py @@ -1,4 +1,5 @@ from pymysql import cursors, OperationalError, Warning +from pymysql.constants import ER from pymysql.tests import base import os @@ -63,6 +64,37 @@ def test_unbuffered_load_file(self): c = conn.cursor() c.execute("DROP TABLE test_load_local") + def test_load_warnings(self): + """Test load local infile produces the appropriate warnings""" + conn = self.connect() + c = conn.cursor() + c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "data", + "load_local_warn_data.txt", + ) + try: + c.execute( + ( + "LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + + "test_load_local FIELDS TERMINATED BY ','" + ).format(filename) + ) + self.assertEqual(1, c.warning_count) + + c.execute("SHOW WARNINGS") + w = c.fetchone() + + self.assertEqual(ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, w[1]) + self.assertIn( + "incorrect integer value", + w[2].lower(), + ) + finally: + c.execute("DROP TABLE test_load_local") + c.close() + if __name__ == "__main__": import unittest