|
20 | 20 | # misrepresented as being the original software. |
21 | 21 | # 3. This notice may not be removed or altered from any source distribution. |
22 | 22 |
|
23 | | -import unittest |
| 23 | +import contextlib |
24 | 24 | import sqlite3 as sqlite |
| 25 | +import unittest |
25 | 26 |
|
26 | 27 | from test.support.os_helper import TESTFN, unlink |
| 28 | + |
| 29 | +from test.test_sqlite3.test_dbapi import memory_database, cx_limit |
27 | 30 | from test.test_sqlite3.test_userfunctions import with_tracebacks |
28 | 31 |
|
| 32 | + |
29 | 33 | class CollationTests(unittest.TestCase): |
30 | 34 | def test_create_collation_not_string(self): |
31 | 35 | con = sqlite.connect(":memory:") |
@@ -224,6 +228,16 @@ def bad_progress(): |
224 | 228 |
|
225 | 229 |
|
226 | 230 | class TraceCallbackTests(unittest.TestCase): |
| 231 | + @contextlib.contextmanager |
| 232 | + def check_stmt_trace(self, cx, expected): |
| 233 | + try: |
| 234 | + traced = [] |
| 235 | + cx.set_trace_callback(lambda stmt: traced.append(stmt)) |
| 236 | + yield |
| 237 | + finally: |
| 238 | + self.assertEqual(traced, expected) |
| 239 | + cx.set_trace_callback(None) |
| 240 | + |
227 | 241 | def test_trace_callback_used(self): |
228 | 242 | """ |
229 | 243 | Test that the trace callback is invoked once it is set. |
@@ -289,6 +303,51 @@ def trace(statement): |
289 | 303 | con2.close() |
290 | 304 | self.assertEqual(traced_statements, queries) |
291 | 305 |
|
| 306 | + @unittest.skipIf(sqlite.sqlite_version_info < (3, 14, 0), |
| 307 | + "Requires SQLite 3.14.0 or newer") |
| 308 | + def test_trace_expanded_sql(self): |
| 309 | + expected = [ |
| 310 | + "create table t(t)", |
| 311 | + "BEGIN ", |
| 312 | + "insert into t values(0)", |
| 313 | + "insert into t values(1)", |
| 314 | + "insert into t values(2)", |
| 315 | + "COMMIT", |
| 316 | + ] |
| 317 | + with memory_database() as cx, self.check_stmt_trace(cx, expected): |
| 318 | + with cx: |
| 319 | + cx.execute("create table t(t)") |
| 320 | + cx.executemany("insert into t values(?)", ((v,) for v in range(3))) |
| 321 | + |
| 322 | + @with_tracebacks( |
| 323 | + sqlite.DataError, |
| 324 | + regex="Expanded SQL string exceeds the maximum string length" |
| 325 | + ) |
| 326 | + def test_trace_too_much_expanded_sql(self): |
| 327 | + # If the expanded string is too large, we'll fall back to the |
| 328 | + # unexpanded SQL statement. The resulting string length is limited by |
| 329 | + # SQLITE_LIMIT_LENGTH. |
| 330 | + template = "select 'b' as \"a\" from sqlite_master where \"a\"=" |
| 331 | + category = sqlite.SQLITE_LIMIT_LENGTH |
| 332 | + with memory_database() as cx, cx_limit(cx, category=category) as lim: |
| 333 | + nextra = lim - (len(template) + 2) - 1 |
| 334 | + ok_param = "a" * nextra |
| 335 | + bad_param = "a" * (nextra + 1) |
| 336 | + |
| 337 | + unexpanded_query = template + "?" |
| 338 | + with self.check_stmt_trace(cx, [unexpanded_query]): |
| 339 | + cx.execute(unexpanded_query, (bad_param,)) |
| 340 | + |
| 341 | + expanded_query = f"{template}'{ok_param}'" |
| 342 | + with self.check_stmt_trace(cx, [expanded_query]): |
| 343 | + cx.execute(unexpanded_query, (ok_param,)) |
| 344 | + |
| 345 | + @with_tracebacks(ZeroDivisionError, regex="division by zero") |
| 346 | + def test_trace_bad_handler(self): |
| 347 | + with memory_database() as cx: |
| 348 | + cx.set_trace_callback(lambda stmt: 5/0) |
| 349 | + cx.execute("select 1") |
| 350 | + |
292 | 351 |
|
293 | 352 | if __name__ == "__main__": |
294 | 353 | unittest.main() |
0 commit comments