diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 39c9bf5b61143d..7a2c711c504ab8 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -20,14 +20,26 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import threading -import unittest +import contextlib import sqlite3 as sqlite import sys +import threading +import unittest from test.support.os_helper import TESTFN, unlink +# Helper for tests using TESTFN +@contextlib.contextmanager +def managed_connect(*args, **kwargs): + cx = sqlite.connect(*args, **kwargs) + try: + yield cx + finally: + cx.close() + unlink(TESTFN) + + class ModuleTests(unittest.TestCase): def test_api_level(self): self.assertEqual(sqlite.apilevel, "2.0", @@ -135,6 +147,25 @@ def test_failed_open(self): def test_close(self): self.cx.close() + def test_use_after_close(self): + sql = "select 1" + cu = self.cx.cursor() + res = cu.execute(sql) + self.cx.close() + self.assertRaises(sqlite.ProgrammingError, res.fetchall) + self.assertRaises(sqlite.ProgrammingError, cu.execute, sql) + self.assertRaises(sqlite.ProgrammingError, cu.executemany, sql, []) + self.assertRaises(sqlite.ProgrammingError, cu.executescript, sql) + self.assertRaises(sqlite.ProgrammingError, self.cx.execute, sql) + self.assertRaises(sqlite.ProgrammingError, + self.cx.executemany, sql, []) + self.assertRaises(sqlite.ProgrammingError, self.cx.executescript, sql) + self.assertRaises(sqlite.ProgrammingError, + self.cx.create_function, "t", 1, lambda x: x) + with self.assertRaises(sqlite.ProgrammingError): + with self.cx: + pass + def test_exceptions(self): # Optional DB-API extension. self.assertEqual(self.cx.Warning, sqlite.Warning) @@ -170,26 +201,27 @@ def test_in_transaction_ro(self): with self.assertRaises(AttributeError): self.cx.in_transaction = True +class OpenTests(unittest.TestCase): + _sql = "create table test(id integer)" + def test_open_with_path_like_object(self): """ Checks that we can successfully connect to a database using an object that is PathLike, i.e. has __fspath__(). """ - self.addCleanup(unlink, TESTFN) class Path: def __fspath__(self): return TESTFN path = Path() - with sqlite.connect(path) as cx: - cx.execute('create table test(id integer)') + with managed_connect(path) as cx: + cx.execute(self._sql) def test_open_uri(self): - self.addCleanup(unlink, TESTFN) - with sqlite.connect(TESTFN) as cx: - cx.execute('create table test(id integer)') - with sqlite.connect('file:' + TESTFN, uri=True) as cx: - cx.execute('insert into test(id) values(0)') - with sqlite.connect('file:' + TESTFN + '?mode=ro', uri=True) as cx: - with self.assertRaises(sqlite.OperationalError): - cx.execute('insert into test(id) values(1)') + with managed_connect(TESTFN) as cx: + cx.execute(self._sql) + with managed_connect(f"file:{TESTFN}", uri=True) as cx: + cx.execute(self._sql) + with self.assertRaises(sqlite.OperationalError): + with managed_connect(f"file:{TESTFN}?mode=ro", uri=True) as cx: + cx.execute(self._sql) class CursorTests(unittest.TestCase): @@ -942,6 +974,7 @@ def suite(): CursorTests, ExtensionTests, ModuleTests, + OpenTests, SqliteOnConflictTests, ThreadTests, ] diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index 8c60bdcf5d70aa..520a5b9f11cd40 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -254,11 +254,15 @@ def trace(statement): self.addCleanup(unlink, TESTFN) con1 = sqlite.connect(TESTFN, isolation_level=None) con2 = sqlite.connect(TESTFN) - con1.set_trace_callback(trace) - cur = con1.cursor() - cur.execute(queries[0]) - con2.execute("create table bar(x)") - cur.execute(queries[1]) + try: + con1.set_trace_callback(trace) + cur = con1.cursor() + cur.execute(queries[0]) + con2.execute("create table bar(x)") + cur.execute(queries[1]) + finally: + con1.close() + con2.close() self.assertEqual(traced_statements, queries) diff --git a/Modules/_sqlite/cache.c b/Modules/_sqlite/cache.c index fd4e619f6a0115..8196e3c5783727 100644 --- a/Modules/_sqlite/cache.c +++ b/Modules/_sqlite/cache.c @@ -97,9 +97,6 @@ pysqlite_cache_init(pysqlite_Cache *self, PyObject *args, PyObject *kwargs) } self->factory = Py_NewRef(factory); - - self->decref_factory = 1; - return 0; } @@ -108,9 +105,7 @@ cache_traverse(pysqlite_Cache *self, visitproc visit, void *arg) { Py_VISIT(Py_TYPE(self)); Py_VISIT(self->mapping); - if (self->decref_factory) { - Py_VISIT(self->factory); - } + Py_VISIT(self->factory); pysqlite_Node *node = self->first; while (node) { @@ -124,9 +119,7 @@ static int cache_clear(pysqlite_Cache *self) { Py_CLEAR(self->mapping); - if (self->decref_factory) { - Py_CLEAR(self->factory); - } + Py_CLEAR(self->factory); /* iterate over all nodes and deallocate them */ pysqlite_Node *node = self->first; diff --git a/Modules/_sqlite/cache.h b/Modules/_sqlite/cache.h index 083356f93f9e4c..209c80dcd54ad4 100644 --- a/Modules/_sqlite/cache.h +++ b/Modules/_sqlite/cache.h @@ -52,10 +52,6 @@ typedef struct pysqlite_Node* first; pysqlite_Node* last; - - /* if set, decrement the factory function when the Cache is deallocated. - * this is almost always desirable, but not in the pysqlite context */ - int decref_factory; } pysqlite_Cache; extern PyTypeObject *pysqlite_NodeType; diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 7252ccab10b4bc..42618d5e3216f8 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -149,14 +149,6 @@ pysqlite_connection_init(pysqlite_Connection *self, PyObject *args, return -1; } - /* By default, the Cache class INCREFs the factory in its initializer, and - * decrefs it in its deallocator method. Since this would create a circular - * reference here, we're breaking it by decrementing self, and telling the - * cache class to not decref the factory (self) in its deallocator. - */ - self->statement_cache->decref_factory = 0; - Py_DECREF(self); - self->detect_types = detect_types; self->timeout = timeout; (void)sqlite3_busy_timeout(self->db, (int)(timeout*1000)); @@ -258,6 +250,16 @@ connection_clear(pysqlite_Connection *self) return 0; } +static void +connection_close(pysqlite_Connection *self) +{ + if (self->db) { + int rc = sqlite3_close_v2(self->db); + assert(rc == SQLITE_OK); + self->db = NULL; + } +} + static void connection_dealloc(pysqlite_Connection *self) { @@ -266,9 +268,7 @@ connection_dealloc(pysqlite_Connection *self) tp->tp_clear((PyObject *)self); /* Clean up if user has not called .close() explicitly. */ - if (self->db) { - sqlite3_close_v2(self->db); - } + connection_close(self); tp->tp_free(self); Py_DECREF(tp); @@ -353,23 +353,18 @@ static PyObject * pysqlite_connection_close_impl(pysqlite_Connection *self) /*[clinic end generated code: output=a546a0da212c9b97 input=3d58064bbffaa3d3]*/ { - int rc; - if (!pysqlite_check_thread(self)) { return NULL; } - pysqlite_do_all_statements(self, ACTION_FINALIZE, 1); - if (self->db) { - rc = sqlite3_close_v2(self->db); + /* Free pending statements before closing. This implies also cleaning + * up cursors, as they may have strong refs to statements. */ + Py_CLEAR(self->statement_cache); + Py_CLEAR(self->statements); + Py_CLEAR(self->cursors); - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return NULL; - } else { - self->db = NULL; - } + connection_close(self); } Py_RETURN_NONE; @@ -1820,6 +1815,9 @@ static PyObject * pysqlite_connection_enter_impl(pysqlite_Connection *self) /*[clinic end generated code: output=457b09726d3e9dcd input=127d7a4f17e86d8f]*/ { + if (!pysqlite_check_connection(self)) { + return NULL; + } return Py_NewRef((PyObject *)self); }