diff --git a/pymysql/cursors.py b/pymysql/cursors.py index e57fba76..d8a93c78 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -342,7 +342,13 @@ def _do_get_result(self): self._rows = result.rows def __iter__(self): - return iter(self.fetchone, None) + return self + + def __next__(self): + row = self.fetchone() + if row is None: + raise StopIteration + return row Warning = err.Warning Error = err.Error @@ -459,9 +465,6 @@ def fetchall_unbuffered(self): """ return iter(self.fetchone, None) - def __iter__(self): - return self.fetchall_unbuffered() - def fetchmany(self, size=None): """Fetch many.""" self._check_executed() diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py index 66d968df..16d297f6 100644 --- a/pymysql/tests/test_cursor.py +++ b/pymysql/tests/test_cursor.py @@ -25,6 +25,14 @@ def setUp(self): self.test_connection = pymysql.connect(**self.databases[0]) self.addCleanup(self.test_connection.close) + def test_cursor_is_iterator(self): + """Test that the cursor is an iterator""" + conn = self.test_connection + cursor = conn.cursor() + cursor.execute("select * from test") + self.assertEqual(cursor.__iter__(), cursor) + self.assertEqual(cursor.__next__(), ("row1",)) + def test_cleanup_rows_unbuffered(self): conn = self.test_connection cursor = conn.cursor(pymysql.cursors.SSCursor)