diff --git a/pymysql/constants/ER.py b/pymysql/constants/ER.py index ddcc4e90..98729d12 100644 --- a/pymysql/constants/ER.py +++ b/pymysql/constants/ER.py @@ -470,5 +470,8 @@ WRONG_STRING_LENGTH = 1468 ERROR_LAST = 1468 +# MariaDB only +STATEMENT_TIMEOUT = 1969 +QUERY_TIMEOUT = 3024 # https://github.com/PyMySQL/PyMySQL/issues/607 CONSTRAINT_FAILED = 4025 diff --git a/pymysql/tests/base.py b/pymysql/tests/base.py index a87307a5..ff33bc4e 100644 --- a/pymysql/tests/base.py +++ b/pymysql/tests/base.py @@ -49,6 +49,14 @@ def mysql_server_is(self, conn, version_tuple): ) return server_version_tuple >= version_tuple + def get_mysql_vendor(self, conn): + server_version = conn.get_server_info() + + if "MariaDB" in server_version: + return "mariadb" + + return "mysql" + _connections = None @property diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py index d19d3e5d..9cb5bafe 100644 --- a/pymysql/tests/test_SSCursor.py +++ b/pymysql/tests/test_SSCursor.py @@ -1,15 +1,8 @@ -import sys +import pytest -try: - from pymysql.tests import base - import pymysql.cursors - from pymysql.constants import CLIENT, ER -except Exception: - # For local testing from top-level directory, without installing - sys.path.append("../pymysql") - from pymysql.tests import base - import pymysql.cursors - from pymysql.constants import CLIENT, ER +from pymysql.tests import base +import pymysql.cursors +from pymysql.constants import CLIENT, ER class TestSSCursor(base.PyMySQLTestCase): @@ -122,6 +115,92 @@ def test_SSCursor(self): cursor.execute("DROP TABLE IF EXISTS tz_data") cursor.close() + def test_execution_time_limit(self): + # this method is similarly implemented in test_cursor + + conn = self.connect() + + # table creation and filling is SSCursor only as it's not provided by self.setUp() + self.safe_create_table( + conn, + "test", + "create table test (data varchar(10))", + ) + with conn.cursor() as cur: + cur.execute( + "insert into test (data) values " + "('row1'), ('row2'), ('row3'), ('row4'), ('row5')" + ) + conn.commit() + + db_type = self.get_mysql_vendor(conn) + + with conn.cursor(pymysql.cursors.SSCursor) as cur: + # MySQL MAX_EXECUTION_TIME takes ms + # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1 + + # this will sleep 0.01 seconds per row + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + + cur.execute(sql) + # unlike Cursor, SSCursor returns a list of tuples here + self.assertEqual( + cur.fetchall(), + [ + ("row1", 0), + ("row2", 0), + ("row3", 0), + ("row4", 0), + ("row5", 0), + ], + ) + + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + cur.execute(sql) + self.assertEqual(cur.fetchone(), ("row1", 0)) + + # this discards the previous unfinished query and raises an + # incomplete unbuffered query warning + with pytest.warns(UserWarning): + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + # SSCursor will not read the EOF packet until we try to read + # another row. Skipping this will raise an incomplete unbuffered + # query warning in the next cur.execute(). + self.assertEqual(cur.fetchone(), None) + + if db_type == "mysql": + sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test" + else: + sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test" + with pytest.raises(pymysql.err.OperationalError) as cm: + # in an unbuffered cursor the OperationalError may not show up + # until fetching the entire result + cur.execute(sql) + cur.fetchall() + + if db_type == "mysql": + # this constant was only introduced in MySQL 5.7, not sure + # what was returned before, may have been ER_QUERY_INTERRUPTED + self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT) + else: + self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT) + + # connection should still be fine at this point + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + def test_warnings(self): con = self.connect() cur = con.cursor(pymysql.cursors.SSCursor) diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py index 63ecce02..66d968df 100644 --- a/pymysql/tests/test_cursor.py +++ b/pymysql/tests/test_cursor.py @@ -2,6 +2,8 @@ from pymysql.tests import base import pymysql.cursors +import pytest + class CursorTest(base.PyMySQLTestCase): def setUp(self): @@ -18,6 +20,7 @@ def setUp(self): "insert into test (data) values " "('row1'), ('row2'), ('row3'), ('row4'), ('row5')" ) + conn.commit() cursor.close() self.test_connection = pymysql.connect(**self.databases[0]) self.addCleanup(self.test_connection.close) @@ -129,6 +132,70 @@ def test_executemany(self): finally: cursor.execute("DROP TABLE IF EXISTS percent_test") + def test_execution_time_limit(self): + # this method is similarly implemented in test_SScursor + + conn = self.test_connection + db_type = self.get_mysql_vendor(conn) + + with conn.cursor(pymysql.cursors.Cursor) as cur: + # MySQL MAX_EXECUTION_TIME takes ms + # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1 + + # this will sleep 0.01 seconds per row + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + + cur.execute(sql) + # unlike SSCursor, Cursor returns a tuple of tuples here + self.assertEqual( + cur.fetchall(), + ( + ("row1", 0), + ("row2", 0), + ("row3", 0), + ("row4", 0), + ("row5", 0), + ), + ) + + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + cur.execute(sql) + self.assertEqual(cur.fetchone(), ("row1", 0)) + + # this discards the previous unfinished query + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + if db_type == "mysql": + sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test" + else: + sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test" + with pytest.raises(pymysql.err.OperationalError) as cm: + # in a buffered cursor this should reliably raise an + # OperationalError + cur.execute(sql) + + if db_type == "mysql": + # this constant was only introduced in MySQL 5.7, not sure + # what was returned before, may have been ER_QUERY_INTERRUPTED + self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT) + else: + self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT) + + # connection should still be fine at this point + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + def test_warnings(self): con = self.connect() cur = con.cursor()