|
21 | 21 | # misrepresented as being the original software. |
22 | 22 | # 3. This notice may not be removed or altered from any source distribution. |
23 | 23 |
|
| 24 | +import contextlib |
| 25 | +import functools |
| 26 | +import io |
24 | 27 | import unittest |
25 | 28 | import unittest.mock |
26 | 29 | import gc |
27 | 30 | import sqlite3 as sqlite |
28 | 31 |
|
| 32 | +def with_tracebacks(strings): |
| 33 | + """Convenience decorator for testing callback tracebacks.""" |
| 34 | + strings.append('Traceback') |
| 35 | + |
| 36 | + def decorator(func): |
| 37 | + @functools.wraps(func) |
| 38 | + def wrapper(self, *args, **kwargs): |
| 39 | + # First, run the test with traceback enabled. |
| 40 | + sqlite.enable_callback_tracebacks(True) |
| 41 | + buf = io.StringIO() |
| 42 | + with contextlib.redirect_stderr(buf): |
| 43 | + func(self, *args, **kwargs) |
| 44 | + tb = buf.getvalue() |
| 45 | + for s in strings: |
| 46 | + self.assertIn(s, tb) |
| 47 | + |
| 48 | + # Then run the test with traceback disabled. |
| 49 | + sqlite.enable_callback_tracebacks(False) |
| 50 | + func(self, *args, **kwargs) |
| 51 | + return wrapper |
| 52 | + return decorator |
| 53 | + |
29 | 54 | def func_returntext(): |
30 | 55 | return "foo" |
31 | 56 | def func_returnunicode(): |
@@ -228,6 +253,7 @@ def test_func_return_long_long(self): |
228 | 253 | val = cur.fetchone()[0] |
229 | 254 | self.assertEqual(val, 1<<31) |
230 | 255 |
|
| 256 | + @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) |
231 | 257 | def test_func_exception(self): |
232 | 258 | cur = self.con.cursor() |
233 | 259 | with self.assertRaises(sqlite.OperationalError) as cm: |
@@ -387,20 +413,23 @@ def test_aggr_no_finalize(self): |
387 | 413 | val = cur.fetchone()[0] |
388 | 414 | self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") |
389 | 415 |
|
| 416 | + @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) |
390 | 417 | def test_aggr_exception_in_init(self): |
391 | 418 | cur = self.con.cursor() |
392 | 419 | with self.assertRaises(sqlite.OperationalError) as cm: |
393 | 420 | cur.execute("select excInit(t) from test") |
394 | 421 | val = cur.fetchone()[0] |
395 | 422 | self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") |
396 | 423 |
|
| 424 | + @with_tracebacks(['step', '5/0', 'ZeroDivisionError']) |
397 | 425 | def test_aggr_exception_in_step(self): |
398 | 426 | cur = self.con.cursor() |
399 | 427 | with self.assertRaises(sqlite.OperationalError) as cm: |
400 | 428 | cur.execute("select excStep(t) from test") |
401 | 429 | val = cur.fetchone()[0] |
402 | 430 | self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") |
403 | 431 |
|
| 432 | + @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) |
404 | 433 | def test_aggr_exception_in_finalize(self): |
405 | 434 | cur = self.con.cursor() |
406 | 435 | with self.assertRaises(sqlite.OperationalError) as cm: |
@@ -502,6 +531,14 @@ def authorizer_cb(action, arg1, arg2, dbname, source): |
502 | 531 | raise ValueError |
503 | 532 | return sqlite.SQLITE_OK |
504 | 533 |
|
| 534 | + @with_tracebacks(['authorizer_cb', 'ValueError']) |
| 535 | + def test_table_access(self): |
| 536 | + super().test_table_access() |
| 537 | + |
| 538 | + @with_tracebacks(['authorizer_cb', 'ValueError']) |
| 539 | + def test_column_access(self): |
| 540 | + super().test_table_access() |
| 541 | + |
505 | 542 | class AuthorizerIllegalTypeTests(AuthorizerTests): |
506 | 543 | @staticmethod |
507 | 544 | def authorizer_cb(action, arg1, arg2, dbname, source): |
|
0 commit comments