From af8162346a4849c341f5dae3f0f519c0a13f9864 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 19 Apr 2020 17:04:27 +0200 Subject: [PATCH 001/194] Bump version number, start working on 5.2 --- docs/about.txt | 2 +- docs/announce.rst | 8 ++++---- docs/conf.py | 4 ++-- docs/contents/changelog.rst | 5 +++++ setup.py | 4 ++-- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/about.txt b/docs/about.txt index 54e98d39..6d15e3c3 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.1.2 needs PostgreSQL 9.0 to 9.6 or 10 to 12, and +The current version PyGreSQL 5.2 needs PostgreSQL 9.0 to 9.6 or 10 to 12, and Python 2.6, 2.7 or 3.3 to 3.8. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index 73ddce93..0efe650a 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -2,11 +2,11 @@ PyGreSQL Announcements ====================== ---------------------------------- -Release of PyGreSQL version 5.1.2 ---------------------------------- +------------------------------- +Release of PyGreSQL version 5.2 +------------------------------- -Release 5.1.2 of PyGreSQL. +Release 5.2 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. diff --git a/docs/conf.py b/docs/conf.py index 7c0919a5..add72e63 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,9 +68,9 @@ # built documents. # # The short X.Y version. -version = '5.1' +version = '5.2' # The full version, including alpha/beta/rc tags. -release = '5.1.2' +release = '5.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index d5ef4f05..a05e1772 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,11 @@ ChangeLog ========= +Version 5.2 (to be released) +---------------------------- +- ... + + Version 5.1.2 (2020-04-19) -------------------------- - Improved handling of build_ext options for disabling certain features. diff --git a/setup.py b/setup.py index 3e8d9c9e..368f4b59 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.1.2 +"""Setup script for PyGreSQL version 5.2 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. 
It embeds the PostgreSQL query library to allow @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.1.2' +version = '5.2' if (not (2, 6) <= sys.version_info[:2] < (3, 0) and not (3, 3) <= sys.version_info[:2] < (4, 0)): From f6bbd69fb709dd8d1a4d67bcb7ad0e270034e2d6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 19 Apr 2020 20:10:57 +0200 Subject: [PATCH 002/194] Desupport old Python versions, require 2.7 or 3.5+ --- .travis.yml | 2 +- docs/about.txt | 4 +- docs/announce.rst | 2 +- docs/contents/changelog.rst | 2 +- docs/contents/install.rst | 2 +- pg.py | 126 ++++++++----------------------- pgdb.py | 9 +-- setup.py | 5 +- tests/test_classic_connection.py | 8 +- tests/test_classic_dbwrapper.py | 39 +++------- tests/test_dbapi20.py | 26 +------ tox.ini | 7 +- 12 files changed, 56 insertions(+), 176 deletions(-) diff --git a/.travis.yml b/.travis.yml index c9976f0d..4bb9dd99 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ language: python python: - "2.7" - - "3.4" - "3.5" - "3.6" - "3.7" + - "3.8" install: - pip install . diff --git a/docs/about.txt b/docs/about.txt index 6d15e3c3..0274f84f 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -37,6 +37,6 @@ D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. The current version PyGreSQL 5.2 needs PostgreSQL 9.0 to 9.6 or 10 to 12, and -Python 2.6, 2.7 or 3.3 to 3.8. If you need to support older PostgreSQL versions -or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that +Python 2.7 or 3.5 to 3.8. If you need to support older PostgreSQL versions or +older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index 0efe650a..4c0194d8 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -23,7 +23,7 @@ This version has been built and unit tested on: - Ubuntu - Windows 7 and 10 with both MinGW and Visual Studio - PostgreSQL 9.0 to 9.6 and 10 to 12 (32 and 64bit) - - Python 2.6, 2.7 and 3.3 to 3.8 (32 and 64bit) + - Python 2.7 and 3.5 to 3.8 (32 and 64bit) | D'Arcy J.M. Cain | darcy@PyGreSQL.org diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index a05e1772..3af7d802 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,7 +3,7 @@ ChangeLog Version 5.2 (to be released) ---------------------------- -- ... +- We now Python version 2.7 or 3.5 and newer Version 5.1.2 (2020-04-19) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 6cdee9da..126e791b 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -11,7 +11,7 @@ If you are on Windows, make sure that the directory that contains libpq.dll is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.6, 2.7 and 3.3 to 3.8, and PostgreSQL versions 9.0 to 9.6 and 10 to 12. +2.7 and 3.5 to 3.8, and PostgreSQL versions 9.0 to 9.6 and 10 to 12. 
PyGreSQL will be installed as three modules, a shared library called _pg.so (on Linux) or a DLL called _pg.pyd (on Windows), and two pure diff --git a/pg.py b/pg.py index b0fa6674..a8fd0c68 100644 --- a/pg.py +++ b/pg.py @@ -75,8 +75,7 @@ from datetime import date, time, datetime, timedelta, tzinfo from decimal import Decimal from math import isnan, isinf -from collections import namedtuple -from keyword import iskeyword +from collections import namedtuple, OrderedDict from operator import itemgetter from functools import partial from re import compile as regex @@ -179,91 +178,6 @@ def wrapper(arg): # Auxiliary classes and functions that are independent from a DB connection: -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = dict - - - class AttrDict(dict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args, **kw): - if len(args) > 1 or kw: - raise TypeError - items = args[0] if args else [] - if isinstance(items, dict): - raise TypeError - items = list(items) - self._keys = [item[0] for item in items] - dict.__init__(self, items) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error - - def __setitem__(self, key, value): - if self._read_only: - self._read_only_error() - dict.__setitem__(self, key, value) - - def __delitem__(self, key): - if self._read_only: - self._read_only_error() - dict.__delitem__(self, key) - - def __iter__(self): - return iter(self._keys) - - def keys(self): - return list(self._keys) - - def values(self): - return [self[key] for key in self] - - def items(self): - return [(key, self[key]) for key in self] - - def iterkeys(self): - return self.__iter__() - - def itervalues(self): - return iter(self.values()) - - def iteritems(self): - return iter(self.items()) - - @staticmethod - def _read_only_error(*args, **kw): - raise TypeError('This object is read-only') - -else: - - class AttrDict(OrderedDict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args, **kw): - self._read_only = False - OrderedDict.__init__(self, *args, **kw) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error - - def __setitem__(self, key, value): - if self._read_only: - self._read_only_error() - OrderedDict.__setitem__(self, key, value) - - def __delitem__(self, key): - if self._read_only: - self._read_only_error() - OrderedDict.__delitem__(self, key) - - @staticmethod - def _read_only_error(*args, **kw): - raise TypeError('This object is read-only') - try: from inspect import signature except ImportError: # Python < 3.3 @@ -356,8 +270,8 @@ def __init__(self): self[key] = typ self['_%s' % key] = '%s[]' % typ - # this could be a static method in Python > 2.6 - def __missing__(self, key): + @staticmethod + def __missing__(key): return 'text' _simpletypes = _SimpleTypes() @@ -428,6 +342,32 @@ class Literal(str): """Wrapper class for marking literal SQL values.""" +class AttrDict(OrderedDict): + """Simple read-only ordered dictionary for storing attribute names.""" + + def __init__(self, *args, **kw): + self._read_only = False + OrderedDict.__init__(self, *args, **kw) + self._read_only = True + error = self._read_only_error + self.clear = self.update = error + self.pop = self.setdefault = self.popitem = error + + def __setitem__(self, key, value): + if self._read_only: + 
self._read_only_error() + OrderedDict.__setitem__(self, key, value) + + def __delitem__(self, key): + if self._read_only: + self._read_only_error() + OrderedDict.__delitem__(self, key) + + @staticmethod + def _read_only_error(*args, **kw): + raise TypeError('This object is read-only') + + class Adapter: """Class providing methods for adapting parameters to the database.""" @@ -1328,13 +1268,7 @@ def typecast(self, value, typ): def _row_factory(names): """Get a namedtuple factory for row results with the given names.""" try: - try: - return namedtuple('Row', names, rename=True)._make - except TypeError: # Python 2.6 and 3.0 do not support rename - names = [v if _re_fieldname.match(v) and not iskeyword(v) - else 'column_%d' % (n,) - for n, v in enumerate(names)] - return namedtuple('Row', names)._make + return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names names = ['column_%d' % (n,) for n in range(len(names))] return namedtuple('Row', names)._make diff --git a/pgdb.py b/pgdb.py index 021ba444..528d3037 100644 --- a/pgdb.py +++ b/pgdb.py @@ -112,7 +112,6 @@ except ImportError: # Python < 3.3 from collections import Iterable from collections import namedtuple -from keyword import iskeyword from functools import partial from re import compile as regex from json import loads as jsondecode, dumps as jsonencode @@ -867,13 +866,7 @@ def _op_error(msg): def _row_factory(names): """Get a namedtuple factory for row results with the given names.""" try: - try: - return namedtuple('Row', names, rename=True)._make - except TypeError: # Python 2.6 and 3.0 do not support rename - names = [v if _re_fieldname.match(v) and not iskeyword(v) - else 'column_%d' % (n,) - for n, v in enumerate(names)] - return namedtuple('Row', names)._make + return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names names = ['column_%d' % (n,) for n in range(len(names))] return namedtuple('Row', names)._make diff --git a/setup.py b/setup.py index 368f4b59..fc68c176 100755 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.6, 2.7 and 3.3 to 3.8, +PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.8, and PostgreSQL versions 9.0 to 9.6 and 10 to 12. 
Use as follows: @@ -232,11 +232,8 @@ def finalize_options(self): "Programming Language :: C", 'Programming Language :: Python', 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 3f0e6bbf..bd8f9d90 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -462,12 +462,8 @@ def testNamedresultWithGoodFieldnames(self): self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias')) def testNamedresultWithBadFieldnames(self): - try: - r = namedtuple('Bad', ['?'] * 6, rename=True) - except TypeError: # Python 2.6 or 3.0 - fields = tuple('column_%d' % n for n in range(6)) - else: - fields = r._fields + r = namedtuple('Bad', ['?'] * 6, rename=True) + fields = r._fields q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",' ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') result = [tuple(range(3, 10))] diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d9222560..d6a95883 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -56,14 +56,11 @@ except NameError: # Python >= 3.0 unicode = str -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = dict +from collections import OrderedDict if str is bytes: # noinspection PyUnresolvedReferences from StringIO import StringIO -else: +else: # Python >= 3.0 from io import StringIO windows = os.name == 'nt' @@ -226,11 +223,7 @@ def testAllDBAttributes(self): 'unescape_bytea', 'update', 'upsert', 'use_regtypes', 'user', ] - # __dir__ is not called in Python 2.6 for old-style classes - db_attributes = dir(self.db) if hasattr( - self.db.__class__, '__class__') else self.db.__dir__() - db_attributes = [a for a in db_attributes - if not a.startswith('_')] + db_attributes = [a for a in self.db.__dir__() if not a.startswith('_')] self.assertEqual(attributes, db_attributes) def testAttributeDb(self): @@ -1005,11 +998,6 @@ def testQueryFormatted(self): # test with tuple, inline q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True) r = q.getresult()[0] - if isinstance(r[1], Decimal): - # Python 2.6 cannot compare float and Decimal - r = list(r) - r[1] = float(r[1]) - r = tuple(r) self.assertEqual(r, (3, 2.5, 'hello', t)) # test with dict q = f("select %(a)s::int, %(b)s::real, %(c)s::text, %(d)s::bool", @@ -2944,8 +2932,7 @@ def testGetAsDict(self): self.assertEqual(row.rgb, t[0]) self.assertEqual(row.name, t[1]) self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb') self.assertIsInstance(r, OrderedDict) expected = OrderedDict((row[1], (row[0], row[2])) @@ -2962,8 +2949,7 @@ def testGetAsDict(self): self.assertEqual(row.id, t[0]) self.assertEqual(row.name, t[1]) self.assertEqual(row._asdict(), dict(id=t[0], name=t[1])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb']) 
self.assertIsInstance(r, OrderedDict) expected = OrderedDict((row[:2], row[2:]) for row in colors) @@ -2983,8 +2969,7 @@ def testGetAsDict(self): if named: self.assertEqual(row.name, t[0]) self.assertEqual(row._asdict(), dict(name=t[0])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True) self.assertIsInstance(r, OrderedDict) expected = OrderedDict((row[:2], row[2]) for row in colors) @@ -2995,8 +2980,7 @@ def testGetAsDict(self): self.assertIsInstance(row, str) t = expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True) self.assertIsInstance(r, OrderedDict) expected = OrderedDict((row[1], row[2]) @@ -3008,8 +2992,7 @@ def testGetAsDict(self): self.assertIsInstance(row, str) t = expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, what='id, name', where="rgb like '#b%'", scalar=True) self.assertIsInstance(r, OrderedDict) @@ -3021,8 +3004,7 @@ def testGetAsDict(self): self.assertIsInstance(row, str) t = expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) expected = r r = get_as_dict(table, what=['name', 'id'], where=['id > 1', 'id < 4', "rgb like '#b%'", @@ -3050,8 +3032,7 @@ def testGetAsDict(self): r = get_as_dict(table, order=False) self.assertIsInstance(r, dict) self.assertEqual(r, expected) - if dict is not OrderedDict: # Python > 2.6 - self.assertNotIsInstance(self, OrderedDict) + self.assertNotIsInstance(self, OrderedDict) # test with arbitrary from clause from_table = '(select id, lower(name) as n2 from "%s") as t2' % table # primary key must be passed explicitly in this case diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 81f5c73e..eb37033b 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -38,10 +38,7 @@ except NameError: # Python >= 3.0 long = int -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = None +from collections import OrderedDict class PgBitString: @@ -244,20 +241,12 @@ def test_cursor_with_unnamed_columns(self): res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3)) - old_py = OrderedDict is None # Python 2.6 or 3.0 - # old Python versions cannot rename tuple fields with underscore - if old_py: - self.assertEqual(res._fields, ('column_0', 'column_1', 'column_2')) - else: - self.assertEqual(res._fields, ('_0', '_1', '_2')) + self.assertEqual(res._fields, ('_0', '_1', '_2')) cur.execute("select 1 as one, 2, 3 as three") res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3)) - if old_py: # cannot auto rename with underscore - self.assertEqual(res._fields, ('one', 'column_1', 'three')) - else: - self.assertEqual(res._fields, ('one', '_1', 'three')) + self.assertEqual(res._fields, ('one', '_1', 'three')) def test_cursor_with_badly_named_columns(self): con = self._connect() @@ -266,21 +255,14 @@ def test_cursor_with_badly_named_columns(self): res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2)) - old_py 
= OrderedDict is None # Python 2.6 or 3.0 - if old_py: - self.assertEqual(res._fields, ('abc', 'column_1')) - else: - self.assertEqual(res._fields, ('abc', '_1')) + self.assertEqual(res._fields, ('abc', '_1')) cur.execute('select 1 as snake_case, 2 as "CamelCase",' ' 3 as "kebap-case", 4 as "_bad", 5 as "0bad", 6 as "bad$"') res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3, 4, 5, 6)) - # old Python versions cannot rename tuple fields with underscore self.assertEqual(res._fields[:2], ('snake_case', 'CamelCase')) fields = ('_2', '_3', '_4', '_5') - if old_py: - fields = tuple('column' + field for field in fields) self.assertEqual(res._fields[2:], fields) def test_colnames(self): diff --git a/tox.ini b/tox.ini index c8be87b9..cbf1c161 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,9 @@ # config file for tox [tox] -envlist = py{26,27,33,34,35,36,37,38} +envlist = py{27,35,36,37,38} [testenv] -deps = - py26: unittest2 commands = python setup.py clean --all build_ext --force --inplace --strict - py26: unit2 discover {posargs} - py{27,33,34,35,36,37,38}: python -m unittest discover {posargs} + python -m unittest discover {posargs} From 517a40b5e8e83950f97a0bee65225c634b3bc639 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 19 Apr 2020 20:38:55 +0200 Subject: [PATCH 003/194] unittest2 not needed any more, some PEP8 formatting --- pg.py | 49 +++++++++++---------- pgdb.py | 21 ++++++--- tests/__init__.py | 7 +-- tests/dbapi20.py | 5 +-- tests/test_classic.py | 14 +++--- tests/test_classic_connection.py | 16 +++---- tests/test_classic_dbwrapper.py | 70 ++++++++++++++++-------------- tests/test_classic_functions.py | 5 +-- tests/test_classic_largeobj.py | 5 +-- tests/test_classic_notification.py | 6 +-- tests/test_dbapi20.py | 61 ++++++++++++++------------ tests/test_dbapi20_copy.py | 18 ++++---- tests/test_tutorial.py | 5 +-- 13 files changed, 139 insertions(+), 143 deletions(-) diff --git a/pg.py b/pg.py index a8fd0c68..e2ef2aa4 100644 --- a/pg.py +++ b/pg.py @@ -373,7 +373,8 @@ class Adapter: _bool_true_values = frozenset('t true 1 y yes on'.split()) - _date_literals = frozenset('current_date current_time' + _date_literals = frozenset( + 'current_date current_time' ' current_timestamp localtime localtimestamp'.split()) _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') @@ -1100,6 +1101,7 @@ def dateformat(self): def create_array_cast(self, basecast): """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] + def cast(v): return cast_array(v, basecast) return cast @@ -1108,6 +1110,7 @@ def create_record_cast(self, name, fields, casts): """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] record = namedtuple(name, fields) + def cast(v): return record(*cast_record(v, casts)) return cast @@ -1260,6 +1263,7 @@ def typecast(self, value, typ): _re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') + # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. 
@@ -1977,13 +1981,13 @@ def pkey(self, table, composite=False, flush=False): pkey = pkeys[table] except KeyError: # cache miss, check the database q = ("SELECT a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass" - " AND i.indisprimary ORDER BY a.attnum") % ( + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass" + " AND i.indisprimary ORDER BY a.attnum") % ( _quote_if_unqualified('$1', table),) pkey = self.db.query(q, (table,)).getresult() if not pkey: @@ -2003,9 +2007,8 @@ def pkey(self, table, composite=False, flush=False): def get_databases(self): """Get list of databases in the system.""" - return [s[0] for s in - self.db.query( - 'SELECT datname FROM pg_catalog.pg_database').getresult()] + return [s[0] for s in self.db.query( + 'SELECT datname FROM pg_catalog.pg_database').getresult()] def get_relations(self, kinds=None, system=False): """Get list of relations in connected database of specified kinds. @@ -2019,17 +2022,17 @@ def get_relations(self, kinds=None, system=False): where = [] if kinds: where.append("r.relkind IN (%s)" % - ','.join("'%s'" % k for k in kinds)) + ','.join("'%s'" % k for k in kinds)) if not system: where.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") + " TO 'pg/_%|information/_schema' ESCAPE '/'") where = " WHERE %s" % ' AND '.join(where) if where else '' q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s" - " ORDER BY s.nspname, r.relname") % where + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s" + " ORDER BY s.nspname, r.relname") % where return [r[0] for r in self.db.query(q).getresult()] def get_tables(self, system=False): @@ -2371,7 +2374,7 @@ def upsert(self, table, row=None, **kw): do = 'update set %s' % ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' q = ('INSERT INTO %s AS included (%s) VALUES (%s)' - ' ON CONFLICT (%s) DO %s RETURNING %s') % ( + ' ON CONFLICT (%s) DO %s RETURNING %s') % ( self._escape_qualified_name(table), names, values, target, do, ret) self._do_debug(q, params) @@ -2522,7 +2525,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): return self.db.query(q) def get_as_list(self, table, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + order=None, limit=None, offset=None, scalar=False): """Get a table as a list. This gets a convenient representation of the table as a list @@ -2588,7 +2591,7 @@ def get_as_list(self, table, what=None, where=None, return res def get_as_dict(self, table, keyname=None, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + order=None, limit=None, offset=None, scalar=False): """Get a table as a dictionary. 
This method is similar to get_as_list(), but returns the table @@ -2678,8 +2681,8 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, rows = _namediter(_MemoryQuery(rows, fields)) return cls(zip(keys, rows)) - def notification_handler(self, - event, callback, arg_dict=None, timeout=None, stop_event=None): + def notification_handler(self, event, callback, + arg_dict=None, timeout=None, stop_event=None): """Get notification handler that will run the given callback.""" return NotificationHandler(self, event, callback, arg_dict, timeout, stop_event) diff --git a/pgdb.py b/pgdb.py index 528d3037..bf36d8e4 100644 --- a/pgdb.py +++ b/pgdb.py @@ -636,6 +636,7 @@ def reset(self, typ=None): def create_array_cast(self, basecast): """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] + def cast(v): return cast_array(v, basecast) return cast @@ -644,6 +645,7 @@ def create_record_cast(self, name, fields, casts): """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] record = namedtuple(name, fields) + def cast(v): return record(*cast_record(v, casts)) return cast @@ -753,11 +755,13 @@ def __init__(self, cnx): self._typecasts.connection = cnx if cnx.server_version < 80400: # older remote databases (not officially supported) - self._query_pg_type = ("SELECT oid, typname," + self._query_pg_type = ( + "SELECT oid, typname," " typlen, typtype, null as typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") else: - self._query_pg_type = ("SELECT oid, typname," + self._query_pg_type = ( + "SELECT oid, typname," " typlen, typtype, typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") @@ -778,8 +782,9 @@ def __missing__(self, key): if not res: raise KeyError('Type %s could not be found' % (key,)) res = res[0] - type_code = TypeCode.create(int(res[0]), res[1], - int(res[2]), res[3], res[4], res[5], int(res[6])) + type_code = TypeCode.create( + int(res[0]), res[1], int(res[2]), + res[3], res[4], res[5], int(res[6])) self[type_code.oid] = self[str(type_code)] = type_code return type_code @@ -798,13 +803,14 @@ def get_fields(self, typ): return None if not typ.relid: return None # this type is not composite - self._src.execute("SELECT attname, atttypid" + self._src.execute( + "SELECT attname, atttypid" " FROM pg_catalog.pg_attribute" " WHERE attrelid OPERATOR(pg_catalog.=) %s" " AND attnum OPERATOR(pg_catalog.>) 0" " AND NOT attisdropped ORDER BY attnum" % (typ.relid,)) return [FieldInfo(name, self.get(int(oid))) - for name, oid in self._src.fetch(-1)] + for name, oid in self._src.fetch(-1)] def get_typecast(self, typ): """Get the typecast function for the given database type.""" @@ -858,6 +864,7 @@ def _op_error(msg): _re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') + # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. 
@@ -1737,7 +1744,7 @@ def __ne__(self, other): BINARY = Type('bytea') NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money') DATETIME = Type('date time timetz timestamp timestamptz interval' - ' abstime reltime') # these are very old + ' abstime reltime') # these are very old ROWID = Type('oid') diff --git a/tests/__init__.py b/tests/__init__.py index 38f807e0..f3070dd1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,10 +3,7 @@ You can specify your local database settings in LOCAL_PyGreSQL.py. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest if not (hasattr(unittest, 'skip') and hasattr(unittest.TestCase, 'setUpClass') @@ -18,4 +15,4 @@ def discover(): loader = unittest.TestLoader() suite = loader.discover('.') - return suite \ No newline at end of file + return suite diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 0656cddf..5d77267e 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -6,10 +6,7 @@ __version__ = '1.5' -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import time diff --git a/tests/test_classic.py b/tests/test_classic.py index bb5133ee..397f3fd4 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -3,10 +3,7 @@ from __future__ import print_function -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import sys from functools import partial @@ -65,17 +62,20 @@ def setUp(self): except Error: pass try: - db.query("CREATE TABLE %s._test_schema " + db.query( + "CREATE TABLE %s._test_schema " "(%s int PRIMARY KEY)" % (t, t)) except Error: db.query("DELETE FROM %s._test_schema" % t) try: - db.query("CREATE TABLE _test_schema " + db.query( + "CREATE TABLE _test_schema " "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") except Error: db.query("DELETE FROM _test_schema") try: - db.query("CREATE VIEW _test_vschema AS " + db.query( + "CREATE VIEW _test_vschema AS " "SELECT _test, 'abc'::text AS _test2 FROM _test_schema") except Error: pass diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index bd8f9d90..01559c0d 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -10,10 +10,7 @@ These tests need a database to test against. 
""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import threading import time import os @@ -21,7 +18,7 @@ from collections import namedtuple try: from collections.abc import Iterable -except ImportError: +except ImportError: # Python < 3.3 from collections import Iterable from decimal import Decimal @@ -465,7 +462,7 @@ def testNamedresultWithBadFieldnames(self): r = namedtuple('Bad', ['?'] * 6, rename=True) fields = r._fields q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",' - ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') + ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') result = [tuple(range(3, 10))] r = self.c.query(q).namedresult() self.assertEqual(r, result) @@ -1573,9 +1570,10 @@ def setUpClass(cls): c = connect() c.query("drop table if exists test cascade") c.query("create table test (" - "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time," - "d numeric, f4 real, f8 double precision, m money," - "c char(1), v4 varchar(4), c4 char(4), t text)") + "i2 smallint, i4 integer, i8 bigint," + " b boolean, dt date, ti time," + "d numeric, f4 real, f8 double precision, m money," + "c char(1), v4 varchar(4), c4 char(4), t text)") # Check whether the test database uses SQL_ASCII - this means # that it does not consider encoding when calculating lengths. c.query("set client_encoding=utf8") diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d6a95883..abec52c1 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -10,11 +10,7 @@ These tests need a database to test against. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest - +import unittest import os import sys import gc @@ -1465,8 +1461,9 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable(table, 'n int, alpha smallint, v varchar(3),' - ' gamma char(5), tau text, beta bool') + self.createTable( + table, 'n int, alpha smallint, v varchar(3),' + ' gamma char(5), tau text, beta bool') r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: @@ -1492,7 +1489,8 @@ def testHasTablePrivilege(self): self.assertEqual(can('test', 'delete'), True) self.assertRaises(pg.DataError, can, 'test', 'foobar') self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist') - r = self.db.query('select rolsuper FROM pg_roles' + r = self.db.query( + 'select rolsuper FROM pg_roles' ' where rolname=current_user').getresult()[0][0] if not pg.get_bool(): r = r == 't' @@ -1678,7 +1676,8 @@ def testGetFromView(self): def testGetLittleBobbyTables(self): get = self.db.get query = self.db.query - self.createTable('test_students', + self.createTable( + 'test_students', 'firstname varchar primary key, nickname varchar, grade char(2)', values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'), ('Robert', 'Little Bobby Tables', 'D-')]) @@ -1714,8 +1713,8 @@ def testInsert(self): bool_on = pg.get_bool() decimal = pg.get_decimal() table = 'insert_test_table' - self.createTable(table, - 'i2 smallint, i4 integer, i8 bigint,' + self.createTable( + table, 'i2 smallint, i4 integer, i8 bigint,' ' d numeric, f4 real, f8 double precision, m money,' ' v4 varchar(4), c4 char(4), t text,' ' b boolean, ts timestamp') @@ -1789,8 +1788,8 @@ def testInsert(self): if ts == 'current_timestamp': ts = 
data['ts'] self.assertIsInstance(ts, datetime) - self.assertEqual(ts.strftime('%Y-%m-%d'), - strftime('%Y-%m-%d')) + self.assertEqual( + ts.strftime('%Y-%m-%d'), strftime('%Y-%m-%d')) else: ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S') expect['ts'] = ts @@ -2520,14 +2519,14 @@ def testDeleteWithCompositeKey(self): r = query('select t from "%s" where n=3' % table).getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' - self.createTable(table, - 'n integer, m integer, t text, primary key (n, m)', - values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) - for n in range(3) for m in range(2)]) + self.createTable( + table, 'n integer, m integer, t text, primary key (n, m)', + values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) + for n in range(3) for m in range(2)]) self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1) r = [r[0] for r in query('select t from "%s" where n=2' - ' order by m' % table).getresult()] + ' order by m' % table).getresult()] self.assertEqual(r, ['c']) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0) r = [r[0] for r in query('select t from "%s" where n=3' @@ -2542,7 +2541,8 @@ def testDeleteWithQuotedNames(self): delete = self.db.delete query = self.db.query table = 'test table for delete()' - self.createTable(table, '"Prime!" smallint primary key,' + self.createTable( + table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) r = {'Prime!': 17} @@ -2559,10 +2559,11 @@ def testDeleteWithQuotedNames(self): def testDeleteReferenced(self): delete = self.db.delete query = self.db.query - self.createTable('test_parent', - 'n smallint primary key', values=range(3)) - self.createTable('test_child', - 'n smallint primary key references test_parent', values=range(3)) + self.createTable( + 'test_parent', 'n smallint primary key', values=range(3)) + self.createTable( + 'test_child', 'n smallint primary key references test_parent', + values=range(3)) q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)") self.assertEqual(query(q).getresult()[0], (3, 3)) @@ -2760,7 +2761,8 @@ def testTruncateOnly(self): truncate(['test_parent', 'test_parent_2'], only=False) r = query(q).getresult()[0] self.assertEqual(r, (0, 0, 0, 0)) - self.assertRaises(ValueError, truncate, + self.assertRaises( + ValueError, truncate, ['test_parent*', 'test_child'], only=[True, False]) truncate(['test_parent*', 'test_child'], only=[False, True]) @@ -2794,8 +2796,8 @@ def testGetAsList(self): named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')] - self.createTable(table, - 'id smallint primary key, name varchar', values=names) + self.createTable( + table, 'id smallint primary key, name varchar', values=names) r = get_as_list(table) self.assertIsInstance(r, list) self.assertEqual(r, names) @@ -2823,8 +2825,8 @@ def testGetAsList(self): r = get_as_list(table, what='name', where="name like 'Ma%'") self.assertIsInstance(r, list) self.assertEqual(r, [('Maggie',), ('Marge',)]) - r = get_as_list(table, what='name', - where=["name like 'Ma%'", "name like '%r%'"]) + r = get_as_list( + table, what='name', where=["name like 'Ma%'", "name like '%r%'"]) self.assertIsInstance(r, list) self.assertEqual(r, [('Marge',)]) r = get_as_list(table, what='name', order='id') @@ -3835,7 +3837,8 @@ def testRecordLiteral(self): def testDate(self): query = self.db.query - for datestyle in ('ISO', 'Postgres, 
MDY', 'Postgres, DMY', + for datestyle in ( + 'ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = date(2016, 3, 14) @@ -4004,9 +4007,9 @@ def testHstore(self): except pg.DatabaseError: self.skipTest("hstore extension not enabled") d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever', - '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', - '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', - 'None': None, 'NULL': 'NULL', 'empty': ''} + '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', + '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', + 'None': None, 'NULL': 'NULL', 'empty': ''} q = "select $1::hstore" r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0] self.assertIsInstance(r, dict) @@ -4082,7 +4085,8 @@ def testDbTypesTypecast(self): r = self.db.query("select '0,0,1'::circle").getresult()[0][0] self.assertIn('circle', dbtypes) self.assertEqual(r, 'Squared Circle: <(0,0),1>') - self.assertEqual(dbtypes.typecast('Impossible', 'circle'), + self.assertEqual( + dbtypes.typecast('Impossible', 'circle'), 'Squared Circle: Impossible') dbtypes.reset_typecast('circle') self.assertIsNone(dbtypes.get_typecast('circle')) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index a7311391..076415e7 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -10,10 +10,7 @@ These tests do not need a database to test against. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import json import re diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 38dc5e1f..c826cb86 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -10,10 +10,7 @@ These tests need a database to test against. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import tempfile import os diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 84319d4b..65552b5c 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -10,11 +10,7 @@ These tests need a database to test against. 
""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest - +import unittest import warnings from time import sleep from threading import Thread diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index eb37033b..4ff3a02a 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,10 +1,7 @@ #!/usr/bin/python # -*- coding: utf-8 -*- -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest import pgdb @@ -256,7 +253,8 @@ def test_cursor_with_badly_named_columns(self): self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2)) self.assertEqual(res._fields, ('abc', '_1')) - cur.execute('select 1 as snake_case, 2 as "CamelCase",' + cur.execute( + 'select 1 as snake_case, 2 as "CamelCase",' ' 3 as "kebap-case", 4 as "_bad", 5 as "0bad", 6 as "bad$"') res = cur.fetchone() self.assertIsInstance(res, tuple) @@ -289,8 +287,8 @@ def test_description_fields(self): con = self._connect() cur = con.cursor() cur.execute("select 123456789::int8 col0," - " 123456.789::numeric(41, 13) as col1," - " 'foobar'::char(39) as col2") + " 123456.789::numeric(41, 13) as col1," + " 'foobar'::char(39) as col2") desc = cur.description self.assertIsInstance(desc, list) self.assertEqual(len(desc), 3) @@ -447,7 +445,8 @@ def test_fetch_2_rows(self): try: cur = con.cursor() cur.execute("set datestyle to iso") - cur.execute("create table %s (" + cur.execute( + "create table %s (" "stringtest varchar," "binarytest bytea," "booltest bool," @@ -465,7 +464,8 @@ def test_fetch_2_rows(self): for s in ('numeric', 'monetary', 'time'): cur.execute("set lc_%s to 'C'" % s) for _i in range(2): - cur.execute("insert into %s values (" + cur.execute( + "insert into %s values (" "%%s,%%s,%%s,%%s,%%s,%%s,%%s," "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values) cur.execute("select * from %s" % table) @@ -498,7 +498,8 @@ def test_integrity_error(self): cur.execute("create table %s (i int primary key)" % table) cur.execute("insert into %s values (1)" % table) cur.execute("insert into %s values (2)" % table) - self.assertRaises(pgdb.IntegrityError, cur.execute, + self.assertRaises( + pgdb.IntegrityError, cur.execute, "insert into %s values (1)" % table) finally: con.close() @@ -536,7 +537,7 @@ def test_float(self): self.assertTrue(isnan(nan) and not isinf(nan)) self.assertTrue(isinf(inf) and not isnan(inf)) values = [0, 1, 0.03125, -42.53125, nan, inf, -inf, - 'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity'] + 'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity'] table = self.table_prefix + 'booze' con = self._connect() try: @@ -580,8 +581,8 @@ def test_datetime(self): cur = con.cursor() cur.execute("set timezone = UTC") cur.execute("create table %s (" - "d date, t time, ts timestamp," - "tz timetz, tsz timestamptz)" % table) + "d date, t time, ts timestamp," + "tz timetz, tsz timestamptz)" % table) for n in range(3): values = [dt.date(), dt.time(), dt, dt.time(), dt] @@ -599,13 +600,14 @@ def test_datetime(self): pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), pgdb.Timestamp(*(d + t + z))] for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', - 'sql, mdy', 'sql, dmy', 'german'): + 'sql, mdy', 'sql, dmy', 'german'): cur.execute("set datestyle to %s" % datestyle) if n != 1: cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("insert into %s" + cur.execute( + "insert into %s" " values (%%s,%%s,%%s,%%s,%%s)" % table, params) cur.execute("select * from %s" % 
table) d = cur.description @@ -642,10 +644,10 @@ def test_interval(self): param = pgdb.Interval( td.days, 0, 0, td.seconds, td.microseconds) for intervalstyle in ('sql_standard ', 'postgres', - 'postgres_verbose', 'iso_8601'): + 'postgres_verbose', 'iso_8601'): cur.execute("set intervalstyle to %s" % intervalstyle) cur.execute("insert into %s" - " values (%%s)" % table, [param]) + " values (%%s)" % table, [param]) cur.execute("select * from %s" % table) tc = cur.description[0].type_code self.assertEqual(tc, pgdb.DATETIME) @@ -672,9 +674,9 @@ def test_hstore(self): finally: con.close() d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever', 'back\\': '\\slash', - '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', - '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', - 'None': None, 'NULL': 'NULL', 'empty': ''} + '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', + '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', + 'None': None, 'NULL': 'NULL', 'empty': ''} con = self._connect() try: cur = con.cursor() @@ -699,21 +701,23 @@ def test_uuid(self): self.assertEqual(result, d) def test_insert_array(self): - values = [(None, None), ([], []), ([None], [[None], ['null']]), + values = [ + (None, None), ([], []), ([None], [[None], ['null']]), ([1, 2, 3], [['a', 'b'], ['c', 'd']]), ([20000, 25000, 25000, 30000], - [['breakfast', 'consulting'], ['meeting', 'lunch']]), + [['breakfast', 'consulting'], ['meeting', 'lunch']]), ([0, 1, -1], [['Hello, World!', '"Hi!"'], ['{x,y}', ' x y ']])] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() cur.execute("create table %s" - " (n smallint, i int[], t text[][])" % table) + " (n smallint, i int[], t text[][])" % table) params = [(n, v[0], v[1]) for n, v in enumerate(values)] # Note that we must explicit casts because we are inserting # empty arrays. Otherwise this is not necessary. - cur.executemany("insert into %s values" + cur.executemany( + "insert into %s values" " (%%d,%%s::int[],%%s::text[][])" % table, params) cur.execute("select i, t from %s order by n" % table) d = cur.description @@ -793,7 +797,7 @@ def test_insert_record(self): def test_select_record(self): value = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!', - '(test)', '(x,y)', ' x y ', 'null', None) + '(test)', '(x,y)', ' x y ', 'null', None) con = self._connect() try: cur = con.cursor() @@ -829,7 +833,8 @@ def test_custom_type(self): try: cur = con.cursor() params = (1, object()) # an object that cannot be handled - self.assertRaises(pgdb.InterfaceError, cur.execute, + self.assertRaises( + pgdb.InterfaceError, cur.execute, "insert into %s values (%%s,%%s)" % table, params) finally: con.close() @@ -1319,8 +1324,8 @@ def test_set_row_factory_size(self): info = pgdb._row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual(info.hits, - 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): ids = set() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 7fdca2c0..51d296ef 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -10,10 +10,7 @@ These tests need a database to test against. 
""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest try: from collections.abc import Iterable @@ -139,7 +136,7 @@ def setUpClass(cls): cur.execute("set client_min_messages=warning") cur.execute("drop table if exists copytest cascade") cur.execute("create table copytest (" - "id smallint primary key, name varchar(64))") + "id smallint primary key, name varchar(64))") cur.close() con.commit() cur = con.cursor() @@ -519,8 +516,8 @@ def test_columns(self): self.assertEqual(ret, self.data_text) ret = ''.join(self.copy_to(columns=['id', 'name'])) self.assertEqual(ret, self.data_text) - self.assertRaises(pgdb.ProgrammingError, self.copy_to, - columns=['id', 'age']) + self.assertRaises( + pgdb.ProgrammingError, self.copy_to, columns=['id', 'age']) def test_csv(self): ret = self.copy_to(format='csv') @@ -552,10 +549,11 @@ def test_binary_with_unicode(self): format='binary', decode=True) def test_query(self): - self.assertRaises(ValueError, self.cursor.copy_to, None, + self.assertRaises( + ValueError, self.cursor.copy_to, None, "select name from copytest", columns='noname') - ret = self.cursor.copy_to(None, - "select name||'!' from copytest where id=1941") + ret = self.cursor.copy_to( + None, "select name||'!' from copytest where id=1941") self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 1) diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 10943359..d1295a6b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -3,10 +3,7 @@ from __future__ import print_function -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest from pg import DB from pgdb import connect From b14a09b7c9a4b38cc8f7c99d685a327d04a323ec Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 19 Apr 2020 21:40:20 +0200 Subject: [PATCH 004/194] Use new syntax (set literals and dict comprehensions) --- pg.py | 37 +++++++++++++++++--------------- pgdb.py | 6 +++--- tests/test_classic_connection.py | 2 +- tests/test_classic_dbwrapper.py | 33 +++++++++++++++------------- tests/test_dbapi20.py | 10 ++++----- 5 files changed, 46 insertions(+), 42 deletions(-) diff --git a/pg.py b/pg.py index e2ef2aa4..33a3ae75 100644 --- a/pg.py +++ b/pg.py @@ -226,8 +226,8 @@ def dst(self, dt): # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') + GMT='+0000', HST='-1000', MET='+0100', MST='-0700', + UCT='+0000', UTC='+0000', WET='+0000') def _timezone_as_offset(tz): @@ -693,7 +693,7 @@ def format_query(self, command, values=None, types=None, inline=False): len(types) != len(values)): raise TypeError('The values and types do not match') literals = [add(value, typ) - for value, typ in zip(values, types)] + for value, typ in zip(values, types)] else: literals = [add(value) for value in values] command %= tuple(literals) @@ -712,18 +712,18 @@ def format_query(self, command, values=None, types=None, inline=False): values = used_values if inline: adapt = self.adapt_inline - literals = dict((key, adapt(value)) - for key, value in values.items()) + literals = {key: adapt(value) + for key, value in values.items()} else: add = params.add if types: if not isinstance(types, dict): raise TypeError('The values and types do not match') - literals = dict((key, add(values[key], types.get(key))) - for key in sorted(values)) + literals = 
{key: add(values[key], types.get(key)) + for key in sorted(values)} else: - literals = dict((key, add(values[key])) - for key in sorted(values)) + literals = {key: add(values[key]) + for key in sorted(values)} command %= literals else: raise TypeError('The values must be passed as tuple, list or dict') @@ -827,7 +827,7 @@ def cast_timestamp(value, connection): if len(value[3]) > 4: return datetime.max fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] else: if len(value[0]) > 10: return datetime.max @@ -1159,8 +1159,8 @@ class DbTypes(dict): information on the associated database type. """ - _num_types = frozenset('int float num money' - ' int2 int4 int8 float4 float8 numeric money'.split()) + _num_types = frozenset('int float num money int2 int4 int8' + ' float4 float8 numeric money'.split()) def __init__(self, db): """Initialize type cache for connection.""" @@ -1768,7 +1768,7 @@ def get_parameter(self, parameter): if param == 'all': q = 'SHOW ALL' values = self.db.query(q).getresult() - values = dict(value[:2] for value in values) + values = {value[0]: value[1] for value in values} break if isinstance(values, dict): params[param] = key @@ -1823,12 +1823,14 @@ def set_parameter(self, parameter, value=None, local=False): if len(value) == 1: value = value.pop() if not(value is None or isinstance(value, basestring)): - raise ValueError('A single value must be specified' + raise ValueError( + 'A single value must be specified' ' when parameter is a set') parameter = dict.fromkeys(parameter, value) elif isinstance(parameter, dict): if value is not None: - raise ValueError('A value must not be specified' + raise ValueError( + 'A value must not be specified' ' when parameter is a dictionary') else: raise TypeError( @@ -1843,7 +1845,8 @@ def set_parameter(self, parameter, value=None, local=False): raise TypeError('Invalid parameter') if param == 'all': if value is not None: - raise ValueError('A value must ot be specified' + raise ValueError( + 'A value must ot be specified' " when parameter is 'all'") params = {'all': None} break @@ -1886,7 +1889,7 @@ def query(self, command, *args): return self.db.query(command) def query_formatted(self, command, - parameters=None, types=None, inline=False): + parameters=None, types=None, inline=False): """Execute a formatted SQL command string. 
Similar to query, but using Python format placeholders of the form diff --git a/pgdb.py b/pgdb.py index bf36d8e4..cdfb171e 100644 --- a/pgdb.py +++ b/pgdb.py @@ -281,8 +281,8 @@ def dst(self, dt): # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') + GMT='+0000', HST='-1000', MET='+0100', MST='-0700', + UCT='+0000', UTC='+0000', WET='+0000') def _timezone_as_offset(tz): @@ -521,7 +521,7 @@ def cast_interval(value): raise ValueError('Cannot parse interval: %s' % value) days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) + seconds=secs, microseconds=usecs) class Typecasts(dict): diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 01559c0d..a56fc026 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1685,7 +1685,7 @@ def testInserttableFromTupleOfLists(self): self.assertEqual(self.get_back(), self.data) def testInserttableFromSetofTuples(self): - data = set(row for row in self.data) + data = {row for row in self.data} try: self.c.inserttable('test', data) except TypeError as e: diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index abec52c1..d9362889 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -19,6 +19,7 @@ import pg # the module under test +from collections import OrderedDict from decimal import Decimal from datetime import date, time, datetime, timedelta from uuid import UUID @@ -52,8 +53,6 @@ except NameError: # Python >= 3.0 unicode = str -from collections import OrderedDict - if str is bytes: # noinspection PyUnresolvedReferences from StringIO import StringIO else: # Python >= 3.0 @@ -660,10 +659,10 @@ def testGetParameter(self): self.assertEqual(r, ['hex', 'C']) r = f(('standard_conforming_strings', 'datestyle', 'bytea_output')) self.assertEqual(r, ['on', 'ISO, YMD', 'hex']) - r = f(set(['bytea_output', 'lc_monetary'])) + r = f({'bytea_output', 'lc_monetary'}) self.assertIsInstance(r, dict) self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'}) - r = f(set(['Bytea_Output', ' LC_Monetary '])) + r = f({'Bytea_Output', ' LC_Monetary '}) self.assertIsInstance(r, dict) self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'}) s = dict.fromkeys(('bytea_output', 'lc_monetary')) @@ -720,13 +719,15 @@ def testSetParameter(self): f(('escape_string_warning', 'standard_conforming_strings'), 'off') self.assertEqual(g('escape_string_warning'), 'off') self.assertEqual(g('standard_conforming_strings'), 'off') - f(set(['escape_string_warning', 'standard_conforming_strings']), 'on') + f({'escape_string_warning', 'standard_conforming_strings'}, 'on') self.assertEqual(g('escape_string_warning'), 'on') self.assertEqual(g('standard_conforming_strings'), 'on') - self.assertRaises(ValueError, f, set(['escape_string_warning', - 'standard_conforming_strings']), ['off', 'on']) - f(set(['escape_string_warning', 'standard_conforming_strings']), - ['off', 'off']) + self.assertRaises( + ValueError, f, + {'escape_string_warning', 'standard_conforming_strings'}, + ['off', 'on']) + f({'escape_string_warning', 'standard_conforming_strings'}, + ['off', 'off']) self.assertEqual(g('escape_string_warning'), 'off') self.assertEqual(g('standard_conforming_strings'), 'off') f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'}) @@ -769,7 
+770,7 @@ def testResetParameter(self): f('standard_conforming_strings', not_scs) self.assertEqual(g('escape_string_warning'), not_esw) self.assertEqual(g('standard_conforming_strings'), not_scs) - f(set(['escape_string_warning', 'standard_conforming_strings'])) + f({'escape_string_warning', 'standard_conforming_strings'}) self.assertEqual(g('escape_string_warning'), esw) self.assertEqual(g('standard_conforming_strings'), scs) db.close() @@ -2880,8 +2881,8 @@ def testGetAsList(self): from_table = '(select lower(name) as n2 from "%s") as t2' % table r = get_as_list(from_table) self.assertIsInstance(r, list) - r = set(row[0] for row in r) - expected = set(row[1].lower() for row in names) + r = {row[0] for row in r} + expected = {row[1].lower() for row in names} self.assertEqual(r, expected) r = get_as_list(from_table, order='n2', scalar=True) self.assertIsInstance(r, list) @@ -3030,7 +3031,7 @@ def testGetAsDict(self): self.assertIsInstance(r, OrderedDict) self.assertEqual(len(r), 0) # test with unordered query - expected = dict((row[0], row[1:]) for row in colors) + expected = {row[0]: row[1:] for row in colors} r = get_as_dict(table, order=False) self.assertIsInstance(r, dict) self.assertEqual(r, expected) @@ -3382,7 +3383,8 @@ def testInsertGetJsonb(self): def testArray(self): returns_arrays = pg.get_array() - self.createTable('arraytest', + self.createTable( + 'arraytest', 'id smallint, i2 smallint[], i4 integer[], i8 bigint[],' ' d numeric[], f4 real[], f8 double precision[], m money[],' ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]') @@ -3406,7 +3408,8 @@ def testArray(self): long_decimal = decimal('12345671234.5') odd_money = decimal('1234567123.25') t, f = (True, False) if pg.get_bool() else ('t', 'f') - data = dict(id=42, i2=[42, 1234, None, 0, -1], + data = dict( + id=42, i2=[42, 1234, None, 0, -1], i4=[42, 123456789, None, 0, 1, -1], i8=[long(42), long(123456789123456789), None, long(0), long(1), long(-1)], diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 4ff3a02a..00280eec 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -35,8 +35,6 @@ except NameError: # Python >= 3.0 long = int -from collections import OrderedDict - class PgBitString: """Test object with a PostgreSQL representation as Bit String.""" @@ -161,8 +159,8 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): def row_factory(self, row): - return dict(('column %s' % desc[0], value) - for desc, value in zip(self.description, row)) + return {'column %s' % desc[0]: value + for desc, value in zip(self.description, row)} con = self._connect() con.cursor_type = TestCursor @@ -188,8 +186,8 @@ class TestCursor(pgdb.Cursor): def build_row_factory(self): keys = [desc[0] for desc in self.description] - return lambda row: dict((key, value) - for key, value in zip(keys, row)) + return lambda row: { + key: value for key, value in zip(keys, row)} con = self._connect() con.cursor_type = TestCursor From f983288726b1d77378c18283ad0a9450e23248c2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 20 Apr 2020 23:46:41 +0200 Subject: [PATCH 005/194] Slightly improve building under Windows --- setup.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index fc68c176..1eb601d9 100755 --- a/setup.py +++ b/setup.py @@ -171,32 +171,29 @@ def finalize_options(self): "The installed PostgreSQL version" " does not support ssl info functions.") if sys.platform == 'win32': - bits = platform.architecture()[0] - if bits 
== '64bit': # we need to find libpq64 - for path in os.environ['PATH'].split(os.pathsep) + [ - r'C:\Program Files\PostgreSQL\libpq64']: - library_dir = os.path.join(path, 'lib') - if not os.path.isdir(library_dir): - continue - lib = os.path.join(library_dir, 'libpqdll.') - if not (os.path.exists(lib + 'lib') - or os.path.exists(lib + 'a')): - continue - include_dir = os.path.join(path, 'include') - if not os.path.isdir(include_dir): - continue - if library_dir not in library_dirs: - library_dirs.insert(1, library_dir) + libraries[0] = 'lib' + libraries[0] + for path in os.environ['PATH'].split(os.pathsep): + library_dir = os.path.join(path, 'lib') + if not os.path.isdir(library_dir): + continue + if os.path.exists( + os.path.join(library_dir, libraries[0] + 'dll.lib')): + libraries[0] += 'dll' + elif not os.path.exists( + os.path.join(library_dir, libraries[0] + '.lib')): + continue + if library_dir not in library_dirs: + library_dirs.insert(1, library_dir) + include_dir = os.path.join(path, 'include') + if os.path.isdir(include_dir): if include_dir not in include_dirs: include_dirs.insert(1, include_dir) - libraries[0] += 'dll' # libpqdll instead of libpq - break + break compiler = self.get_compiler() if compiler == 'mingw32': # MinGW - if bits == '64bit': # needs MinGW-w64 + if platform.architecture()[0] == '64bit': # needs MinGW-w64 define_macros.append(('MS_WIN64', None)) elif compiler == 'msvc': # Microsoft Visual C++ - libraries[0] = 'lib' + libraries[0] extra_compile_args[1:] = [ '-J', '-W3', '-WX', '-Dinline=__inline'] # needed for MSVC 9 From f0aef89838c4bcd0162aba1f7d8f2eb2b19f7429 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 11:13:21 +0200 Subject: [PATCH 006/194] Simplify building under Windows a bit more --- setup.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 1eb601d9..049105d3 100755 --- a/setup.py +++ b/setup.py @@ -172,23 +172,9 @@ def finalize_options(self): " does not support ssl info functions.") if sys.platform == 'win32': libraries[0] = 'lib' + libraries[0] - for path in os.environ['PATH'].split(os.pathsep): - library_dir = os.path.join(path, 'lib') - if not os.path.isdir(library_dir): - continue - if os.path.exists( - os.path.join(library_dir, libraries[0] + 'dll.lib')): - libraries[0] += 'dll' - elif not os.path.exists( - os.path.join(library_dir, libraries[0] + '.lib')): - continue - if library_dir not in library_dirs: - library_dirs.insert(1, library_dir) - include_dir = os.path.join(path, 'include') - if os.path.isdir(include_dir): - if include_dir not in include_dirs: - include_dirs.insert(1, include_dir) - break + if os.path.exists(os.path.join( + library_dirs[1], libraries[0] + 'dll.lib')): + libraries[0] += 'dll' compiler = self.get_compiler() if compiler == 'mingw32': # MinGW if platform.architecture()[0] == '64bit': # needs MinGW-w64 From f4b022373069047b542a7b444ab7bdf1a6948a2f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 12:22:15 +0200 Subject: [PATCH 007/194] Add pqlib_info and memory_size features --- docs/contents/changelog.rst | 7 ++++++- docs/contents/install.rst | 9 +++++++++ docs/contents/pg/module.rst | 18 ++++++++++++++++++ docs/contents/pg/query.rst | 15 +++++++++++++++ pgmodule.c | 18 +++++++++++++++++- pgquery.c | 17 +++++++++++++++++ setup.py | 18 +++++++++++++++++- tests/test_classic_connection.py | 21 ++++++++++++++++++--- tests/test_classic_functions.py | 6 ++++++ 9 files changed, 123 insertions(+), 6 
deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 3af7d802..7e2731c1 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,7 +3,12 @@ ChangeLog Version 5.2 (to be released) ---------------------------- -- We now Python version 2.7 or 3.5 and newer +- We now require Python version 2.7 or 3.5 and newer +- Changes to the classic PyGreSQL module (pg): + - New module level function `get_pqlib_version()` that gets the version + of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). + - New query method `memsize()` that gets the memory size allocated by + the query (needs PostgreSQL >= 12 on the client). Version 5.1.2 (2020-04-19) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 126e791b..1843ef41 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -104,6 +104,11 @@ You can find out all possible build options with:: Alternatively, you can also use the corresponding C preprocessor macros like ``SSL_INFO`` directly (see the next section). +Note that if you build PyGreSQL with support for newer features that are not +available in the PQLib installed on the runtime system, you may get an error +when importing PyGreSQL, since these features are missing in the shared library +which will prevent Python from loading it. + Compiling Manually ~~~~~~~~~~~~~~~~~~ @@ -147,7 +152,9 @@ Stand-Alone -DDIRECT_ACCESS direct access methods -DLARGE_OBJECTS large object support -DESCAPING_FUNCS support for newer escaping functions + -DPQLIB_INFO support PQLib information -DSSL_INFO support SSL information + -DMEMORY_SIZE support memory size function On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. @@ -193,7 +200,9 @@ Built-in to Python interpreter -DDIRECT_ACCESS direct access methods -DLARGE_OBJECTS large object support -DESCAPING_FUNCS support for newer escaping functions + -DPQLIB_INFO support PQLib information -DSSL_INFO support SSL information + -DMEMORY_SIZE support memory size function On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 15b1824e..1a92f283 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -63,6 +63,24 @@ Example:: con3 = pg.connect('host=myhost user=bob dbname=testdb connect_timeout=10') con4 = pg.connect('postgresql://bob@myhost/testdb?connect_timeout=10') + +get_pqlib_version -- get the version of libpq +--------------------------------------------- + +.. function:: get_pqlib_version() + + Get the version of libpq that is being used by PyGreSQL + + :returns: the version of libpq + :rtype: int + :raises TypeError: too many arguments + +The number is formed by converting the major, minor, and revision numbers of +the libpq version into two-decimal-digit numbers and appending them together. +For example, version 9.1.2 will be returned as 90102. + +.. versionadded:: 5.2 (needs PostgreSQL >= 9.1) + get/set_defhost -- default server host [DV] ------------------------------------------- diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 2d3b7abb..311efaff 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -380,3 +380,18 @@ This method returns the number of tuples in the query result. .. deprecated:: 5.1 You can use the normal :func:`len` function instead. 
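+
+For example, assuming ``con`` is an open classic connection (a minimal
+sketch; the query is only illustrative)::
+
+    q = con.query("select 1 union select 2")
+    assert len(q) == q.ntuples() == 2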
+ +memsize -- return number of bytes allocated by query result +----------------------------------------------------------- + +.. method:: Query.memsize() + + Return number of bytes allocated by query result + + :returns: number of bytes allocated for the query result + :rtype: int + :raises TypeError: Too many arguments. + +This method returns the number of bytes allocated for the query result. + +.. versionadded:: 5.2 (needs PostgreSQL >= 12) diff --git a/pgmodule.c b/pgmodule.c index 3a1c70be..20c47993 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -280,6 +280,19 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return (PyObject *) conn_obj; } +#ifdef PQLIB_INFO + +/* Get version of libpq that is being used */ +static char pg_get_pqlib_version__doc__[] = +"get_pqlib_version() -- get the version of libpq that is being used"; + +static PyObject * +pg_get_pqlib_version(PyObject *self, PyObject *noargs) { + return PyLong_FromLong(PQlibVersion()); +} + +#endif /* PQLIB_INFO */ + /* Escape string */ static char pg_escape_string__doc__[] = "escape_string(string) -- escape a string for use within SQL"; @@ -1176,7 +1189,6 @@ static struct PyMethodDef pg_methods[] = { METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, {"cast_hstore", (PyCFunction) pg_cast_hstore, METH_O, pg_cast_hstore__doc__}, - #ifdef DEFAULT_VARS {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, @@ -1190,6 +1202,10 @@ static struct PyMethodDef pg_methods[] = { {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, #endif /* DEFAULT_VARS */ +#ifdef PQLIB_INFO + {"get_pqlib_version", (PyCFunction) pg_get_pqlib_version, + METH_NOARGS, pg_get_pqlib_version__doc__}, +#endif /* PQLIB_INFO */ {NULL, NULL} /* sentinel */ }; diff --git a/pgquery.c b/pgquery.c index d90db5dc..54b1466f 100644 --- a/pgquery.c +++ b/pgquery.c @@ -149,6 +149,19 @@ query_next(queryObject *self, PyObject *noargs) return row_tuple; } +#ifdef MEMORY_SIZE + +/* Get number of bytes allocated for PGresult object */ +static char query_memsize__doc__[] = +"memsize() -- return number of bytes allocated by query result"; +static PyObject * +query_memsize(queryObject *self, PyObject *noargs) +{ + return PyLong_FromSize_t(PQresultMemorySize(self->result)); +} + +#endif /* MEMORY_SIZE */ + /* Get number of rows. 
*/ static char query_ntuples__doc__[] = "ntuples() -- return number of tuples returned by query"; @@ -706,6 +719,10 @@ static struct PyMethodDef query_methods[] = { METH_NOARGS, query_listfields__doc__}, {"ntuples", (PyCFunction) query_ntuples, METH_NOARGS, query_ntuples__doc__}, +#ifdef MEMORY_SIZE + {"memsize", (PyCFunction) query_memsize, + METH_NOARGS, query_memsize__doc__}, +#endif /* MEMORY_SIZE */ {NULL, NULL} }; diff --git a/setup.py b/setup.py index 049105d3..1bd3efa7 100755 --- a/setup.py +++ b/setup.py @@ -140,7 +140,9 @@ def initialize_options(self): self.large_objects = None self.default_vars = None self.escaping_funcs = None + self.pqlib_info = None self.ssl_info = None + self.memory_size = None if pg_version < (9, 0): warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.") @@ -163,13 +165,27 @@ def finalize_options(self): (warnings.warn if self.escaping_funcs is None else sys.exit)( "The installed PostgreSQL version" " does not support the newer string escaping functions.") + if self.pqlib_info is None or self.pqlib_info: + if pg_version >= (9, 1): + define_macros.append(('PQLIB_INFO', None)) + else: + (warnings.warn if self.pqlib_info is None else sys.exit)( + "The installed PostgreSQL version" + " does not support PQLib info functions.") if self.ssl_info is None or self.ssl_info: if pg_version >= (9, 5): define_macros.append(('SSL_INFO', None)) else: (warnings.warn if self.ssl_info is None else sys.exit)( "The installed PostgreSQL version" - " does not support ssl info functions.") + " does not support SSL info functions.") + if self.memory_size is None or self.memory_size: + if pg_version >= (12, 0): + define_macros.append(('MEMORY_SIZE', None)) + else: + (warnings.warn if self.memory_size is None else sys.exit)( + "The installed PostgreSQL version" + " does not support the memory size function.") if sys.platform == 'win32': libraries[0] = 'lib' + libraries[0] if os.path.exists(os.path.join( diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index a56fc026..e7ccbd50 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -218,9 +218,10 @@ def testMethodQueryEmpty(self): def testAllQueryMembers(self): query = self.connection.query("select true where false") members = ''' - dictiter dictresult fieldname fieldnum getresult listfields - namediter namedresult ntuples one onedict onenamed onescalar - scalariter scalarresult single singledict singlenamed singlescalar + dictiter dictresult fieldname fieldnum getresult + listfields memsize namediter namedresult ntuples + one onedict onenamed onescalar scalariter scalarresult + single singledict singlenamed singlescalar '''.split() query_members = [a for a in dir(query) if not a.startswith('__') @@ -683,6 +684,20 @@ def testQueryWithOids(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') + def testMemSize(self): + if pg.get_pqlib_version() < 120000: + self.skipTest("pqlib does not support memsize()") + query = self.c.query + q = query("select repeat('foo!', 8)") + size = q.memsize() + self.assertIsInstance(size, long) + self.assertGreaterEqual(size, 32) + self.assertLess(size, 8000) + q = query("select repeat('foo!', 2000)") + size = q.memsize() + self.assertGreaterEqual(size, 8000) + self.assertLess(size, 16000) + class TestUnicodeQueries(unittest.TestCase): """Test unicode strings as queries via a basic pg connection.""" diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 076415e7..5218b98d 100755 --- 
a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -121,6 +121,12 @@ def testDefBase(self): pg.set_defbase(d0) self.assertEqual(pg.get_defbase(), d0) + def testPqlibVersion(self): + v = pg.get_pqlib_version() + self.assertIsInstance(v, long) + self.assertGreater(v, 90000) + self.assertLess(v, 130000) + class TestParseArray(unittest.TestCase): """Test the array parser.""" From 235dae62dbbbcd6be15d5d9c0e7aaad5c04682eb Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 12:58:25 +0200 Subject: [PATCH 008/194] Use PostgreSQL 12 in Travis CI --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 4bb9dd99..99d2d044 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ install: script: python setup.py test addons: - postgresql: "10" + postgresql: "12" services: - postgresql From ceee0cbddd2d812e02cc72fd6d0e913c46249cf2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 13:03:10 +0200 Subject: [PATCH 009/194] Use Bionic to have PostgreSQL 12 in Travis CI --- .travis.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.travis.yml b/.travis.yml index 99d2d044..53e60950 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,9 @@ # Travis CI configuration # see https://docs.travis-ci.com/user/languages/python +os: linux +dist: bionic + language: python python: From 8d6bfefb8d62b1aa39e103ceaed2d97a49daed10 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 13:16:57 +0200 Subject: [PATCH 010/194] Try using PostgreSQL 12 with Xenial in Travis CI --- .travis.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 53e60950..5ea3b622 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,6 @@ # Travis CI configuration # see https://docs.travis-ci.com/user/languages/python -os: linux -dist: bionic - language: python python: @@ -21,9 +18,11 @@ script: python setup.py test addons: postgresql: "12" apt: + sources: + - sourceline: deb http://apt.postgresql.org/pub/repos/apt/ xenial-pgdg main 12 + key_url: https://www.postgresql.org/media/keys/ACCC4CF8.asc packages: - postgresql-12 - postgresql-server-dev-12 services: From f79fb722b79403267eee6c76a5c321176511fd36 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 13:41:06 +0200 Subject: [PATCH 011/194] PostgreSQL 12 is running on different port in Travis --- .travis.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.travis.yml b/.travis.yml index 5ea3b622..483400e5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,8 +25,13 @@ addons: - postgresql-12 - postgresql-server-dev-12 +env: + global: + - PGPORT=5433 + services: - postgresql before_script: + - echo "dbport = 5433" > tests/LOCAL_PyGreSQL.py - psql -U postgres -c 'create database unittest' From cde9d4c75e3f82ca2a8e29c8956a77fc63f8b447 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 13:58:19 +0200 Subject: [PATCH 012/194] Switch back to testing with PostgreSQL 10 on Travis Using a newer version currently causes too much trouble.
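
The new features can still be exercised manually against a newer server.
A minimal sketch, assuming a local "unittest" database and libpq >= 12
on the client (the printed numbers are only illustrative):

    import pg

    # libpq version as an int, e.g. 120002 for libpq 12.2
    print(pg.get_pqlib_version())

    con = pg.connect(dbname='unittest')
    q = con.query("select repeat('foo!', 2000)")
    # bytes allocated for the query result, e.g. somewhere above 8000
    print(q.memsize())
    con.close()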
--- .travis.yml | 14 +------------- tests/test_classic_connection.py | 5 ++++- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/.travis.yml b/.travis.yml index 483400e5..4bb9dd99 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,22 +16,10 @@ install: script: python setup.py test addons: - postgresql: "12" - apt: - sources: - - sourceline: deb http://apt.postgresql.org/pub/repos/apt/ xenial-pgdg main 12 - key_url: https://www.postgresql.org/media/keys/ACCC4CF8.asc - packages: - - postgresql-12 - - postgresql-server-dev-12 - -env: - global: - - PGPORT=5433 + postgresql: "10" services: - postgresql before_script: - - echo "dbport = 5433" > tests/LOCAL_PyGreSQL.py - psql -U postgres -c 'create database unittest' diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index e7ccbd50..4755d222 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -223,7 +223,10 @@ def testAllQueryMembers(self): one onedict onenamed onescalar scalariter scalarresult single singledict singlenamed singlescalar '''.split() - query_members = [a for a in dir(query) + if pg.get_pqlib_version() < 120000: + members.remove('memsize') + query_members = [ + a for a in dir(query) if not a.startswith('__') and a != 'next'] # this is only needed in Python 2 self.assertEqual(members, query_members) From 5f70cda8d982f71c0c4d91f87d8c6978c234d675 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 21 Apr 2020 14:10:21 +0200 Subject: [PATCH 013/194] More PEP8 style fixes --- tests/test_classic_connection.py | 221 +++++++++++++++++------------ tests/test_classic_functions.py | 60 ++++---- tests/test_classic_notification.py | 3 +- tests/test_dbapi20_copy.py | 24 ++-- 4 files changed, 178 insertions(+), 130 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 4755d222..d36c210a 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -119,7 +119,8 @@ def testAllConnectAttributes(self): attributes = '''backend_pid db error host options port protocol_version server_version socket ssl_attributes ssl_in_use status user'''.split() - connection_attributes = [a for a in dir(self.connection) + connection_attributes = [ + a for a in dir(self.connection) if not a.startswith('__') and not self.is_method(a)] self.assertEqual(attributes, connection_attributes) @@ -132,7 +133,8 @@ def testAllConnectMethods(self): prepare putline query query_prepared reset set_cast_hook set_notice_receiver source transaction '''.split() - connection_methods = [a for a in dir(self.connection) + connection_methods = [ + a for a in dir(self.connection) if not a.startswith('__') and self.is_method(a)] self.assertEqual(methods, connection_methods) @@ -356,9 +358,10 @@ def testModuleName(self): def testStr(self): q = ("select 1 as a, 'hello' as h, 'w' as world" - " union select 2, 'xyz', 'uvw'") + " union select 2, 'xyz', 'uvw'") r = self.c.query(q) - self.assertEqual(str(r), + self.assertEqual( + str(r), 'a| h |world\n' '-+-----+-----\n' '1|hello|w \n' @@ -503,14 +506,14 @@ def testGet3Rows(self): def testGet3DictRows(self): q = ("select 3 as alias3" - " union select 1 union select 2 order by 1") + " union select 1 union select 2 order by 1") result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}] r = self.c.query(q).dictresult() self.assertEqual(r, result) def testGet3NamedRows(self): q = ("select 3 as alias3" - " union select 1 union select 2 order by 1") + " union select 1 union select 2 order by 1") result = 
[(1,), (2,), (3,)] r = self.c.query(q).namedresult() self.assertEqual(r, result) @@ -553,17 +556,16 @@ def testBigGetresult(self): def testListfields(self): q = ('select 0 as a, 0 as b, 0 as c,' - ' 0 as c, 0 as b, 0 as a,' - ' 0 as lowercase, 0 as UPPERCASE,' - ' 0 as MixedCase, 0 as "MixedCase",' - ' 0 as a_long_name_with_underscores,' - ' 0 as "A long name with Blanks"') + ' 0 as c, 0 as b, 0 as a,' + ' 0 as lowercase, 0 as UPPERCASE,' + ' 0 as MixedCase, 0 as "MixedCase",' + ' 0 as a_long_name_with_underscores,' + ' 0 as "A long name with Blanks"') r = self.c.query(q).listfields() self.assertIsInstance(r, tuple) result = ('a', 'b', 'c', 'c', 'b', 'a', - 'lowercase', 'uppercase', 'mixedcase', 'MixedCase', - 'a_long_name_with_underscores', - 'A long name with Blanks') + 'lowercase', 'uppercase', 'mixedcase', 'MixedCase', + 'a_long_name_with_underscores', 'A long name with Blanks') self.assertEqual(r, result) def testFieldname(self): @@ -594,12 +596,12 @@ def testNtuples(self): # deprecated self.assertIsInstance(r, int) self.assertEqual(r, 0) q = ("select 1 as a, 2 as b, 3 as c, 4 as d" - " union select 5 as a, 6 as b, 7 as c, 8 as d") + " union select 5 as a, 6 as b, 7 as c, 8 as d") r = self.c.query(q).ntuples() self.assertIsInstance(r, int) self.assertEqual(r, 2) q = ("select 1 union select 2 union select 3" - " union select 4 union select 5 union select 6") + " union select 4 union select 5 union select 6") r = self.c.query(q).ntuples() self.assertIsInstance(r, int) self.assertEqual(r, 6) @@ -608,10 +610,10 @@ def testLen(self): q = "select 1 where false" self.assertEqual(len(self.c.query(q)), 0) q = ("select 1 as a, 2 as b, 3 as c, 4 as d" - " union select 5 as a, 6 as b, 7 as c, 8 as d") + " union select 5 as a, 6 as b, 7 as c, 8 as d") self.assertEqual(len(self.c.query(q)), 2) q = ("select 1 union select 2 union select 3" - " union select 4 union select 5 union select 6") + " union select 4 union select 5 union select 6") self.assertEqual(len(self.c.query(q)), 6) def testQuery(self): @@ -920,41 +922,53 @@ def testQueryWithIntParams(self): self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)]) self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)]) self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)]) - self.assertEqual(query("select 1+$1::numeric", [1]).getresult(), - [(Decimal('2'),)]) - self.assertEqual(query("select 1, $1::integer", (2,) - ).getresult(), [(1, 2)]) - self.assertEqual(query("select 1 union select $1::integer", (2,) - ).getresult(), [(1,), (2,)]) - self.assertEqual(query("select $1::integer+$2", (1, 2) - ).getresult(), [(3,)]) - self.assertEqual(query("select $1::integer+$2", [1, 2] - ).getresult(), [(3,)]) - self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6)) - ).getresult(), [(15,)]) + self.assertEqual( + query("select 1+$1::numeric", [1]).getresult(), [(Decimal('2'),)]) + self.assertEqual( + query("select 1, $1::integer", (2,)).getresult(), [(1, 2)]) + self.assertEqual( + query("select 1 union select $1::integer", (2,)).getresult(), + [(1,), (2,)]) + self.assertEqual( + query("select $1::integer+$2", (1, 2)).getresult(), [(3,)]) + self.assertEqual( + query("select $1::integer+$2", [1, 2]).getresult(), [(3,)]) + self.assertEqual( + query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))).getresult(), + [(15,)]) def testQueryWithStrParams(self): query = self.c.query - self.assertEqual(query("select $1||', world!'", ('Hello',) - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1||', 
world!'", ['Hello'] - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'), - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1::text", ('Hello, world!',) - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world') - ).getresult(), [('Hello', 'world')]) - self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world'] - ).getresult(), [('Hello', 'world')]) - self.assertEqual(query("select $1::text union select $2::text", - ('Hello', 'world')).getresult(), [('Hello',), ('world',)]) + self.assertEqual( + query("select $1||', world!'", ('Hello',)).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1||', world!'", ['Hello']).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", ('Hello', 'world')).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1::text", ('Hello, world!',)).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1::text,$2::text", ('Hello', 'world')).getresult(), + [('Hello', 'world')]) + self.assertEqual( + query("select $1::text,$2::text", ['Hello', 'world']).getresult(), + [('Hello', 'world')]) + self.assertEqual( + query("select $1::text union select $2::text", + ('Hello', 'world')).getresult(), + [('Hello',), ('world',)]) try: query("select 'wörld'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', - 'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", + ('Hello', 'w\xc3\xb6rld')).getresult(), + [('Hello, w\xc3\xb6rld!',)]) def testQueryWithUnicodeParams(self): query = self.c.query @@ -963,8 +977,9 @@ def testQueryWithUnicodeParams(self): query("select 'wörld'").getresult()[0][0] == 'wörld' except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - self.assertEqual(query("select $1||', '||$2||'!'", - ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult(), + [('Hello, wörld!',)]) def testQueryWithUnicodeParamsLatin1(self): query = self.c.query @@ -978,19 +993,22 @@ def testQueryWithUnicodeParamsLatin1(self): self.assertEqual(r, [('Hello, wörld!',)]) else: self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) query('set client_encoding=iso_8859_1') - r = query("select $1||', '||$2||'!'", - ('Hello', u'wörld')).getresult() + r = query( + "select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() if unicode_strings: self.assertEqual(r, [('Hello, wörld!',)]) else: self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) query('set client_encoding=sql_ascii') - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'wörld')) def testQueryWithUnicodeParamsCyrillic(self): @@ -1000,38 +1018,44 @@ def testQueryWithUnicodeParamsCyrillic(self): query("select 'мир'").getresult()[0][0] == 'мир' except (pg.DataError, pg.NotSupportedError): 
self.skipTest("database does not support cyrillic") - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'wörld')) - r = query("select $1||', '||$2||'!'", - ('Hello', u'мир')).getresult() + r = query( + "select $1||', '||$2||'!'", ('Hello', u'мир')).getresult() if unicode_strings: self.assertEqual(r, [('Hello, мир!',)]) else: self.assertEqual(r, [(u'Hello, мир!'.encode('cyrillic'),)]) query('set client_encoding=sql_ascii') - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир!')) def testQueryWithMixedParams(self): - self.assertEqual(self.c.query("select $1+2,$2||', world!'", + self.assertEqual( + self.c.query("select $1+2,$2||', world!'", (1, 'Hello'),).getresult(), [(3, 'Hello, world!')]) - self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text", + self.assertEqual( + self.c.query("select $1::integer,$2::date,$3::text", (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')]) def testQueryWithDuplicateParams(self): - self.assertRaises(pg.ProgrammingError, - self.c.query, "select $1+$1", (1,)) - self.assertRaises(pg.ProgrammingError, - self.c.query, "select $1+$1", (1, 2)) + self.assertRaises( + pg.ProgrammingError, self.c.query, "select $1+$1", (1,)) + self.assertRaises( + pg.ProgrammingError, self.c.query, "select $1+$1", (1, 2)) def testQueryWithZeroParams(self): - self.assertEqual(self.c.query("select 1+1", [] - ).getresult(), [(2,)]) + self.assertEqual( + self.c.query("select 1+1", []).getresult(), [(2,)]) def testQueryWithGarbage(self): garbage = r"'\{}+()-#[]oo324" - self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,) - ).dictresult(), [{'garbage': garbage}]) + self.assertEqual( + self.c.query("select $1::text AS garbage", + (garbage,)).dictresult(), + [{'garbage': garbage}]) class TestPreparedQueries(unittest.TestCase): @@ -1056,8 +1080,8 @@ def testDuplicatePreparedStatement(self): self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2') def testNonExistentPreparedStatement(self): - self.assertRaises(pg.OperationalError, - self.c.query_prepared, 'does-not-exist') + self.assertRaises( + pg.OperationalError, self.c.query_prepared, 'does-not-exist') def testUnnamedQueryWithoutParams(self): self.assertIsNone(self.c.prepare('', "select 'anon'")) @@ -1266,7 +1290,8 @@ def testDictIterate(self): self.assertIsInstance(r[1], dict) def testDictIterateTwoColumns(self): - r = self.c.query("select 1 as one, 2 as two" + r = self.c.query( + "select 1 as one, 2 as two" " union select 3 as one, 4 as two").dictiter() self.assertIsInstance(r, Iterable) r = list(r) @@ -1295,7 +1320,8 @@ def testNamedIterate(self): self.assertEqual(r[1].number, 4) def testNamedIterateTwoColumns(self): - r = self.c.query("select 1 as one, 2 as two" + r = self.c.query( + "select 1 as one, 2 as two" " union select 3 as one, 4 as two").namediter() self.assertIsInstance(r, Iterable) r = list(r) @@ -1744,9 +1770,9 @@ def testInserttableNullValues(self): def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), - True, '2999-12-31', '11:59:59', 1e99, - 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, - "1", "1234", "1234", "1234" * 100)] + True, '2999-12-31', '11:59:59', 1e99, + 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, + "1", "1234", "1234", "1234" * 100)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -1757,11 +1783,13 
@@ def testInserttableByteValues(self): self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', + row_unicode = ( + 0, 0, long(0), False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") - row_bytes = tuple(s.encode('utf-8') - if isinstance(s, unicode) else s for s in row_unicode) + row_bytes = tuple( + s.encode('utf-8') if isinstance(s, unicode) else s + for s in row_unicode) data = [row_bytes] * 2 self.c.inserttable('test', data) if unicode_strings: @@ -1775,14 +1803,16 @@ def testInserttableUnicodeUtf8(self): self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', + row_unicode = ( + 0, 0, long(0), False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) if not unicode_strings: - row_bytes = tuple(s.encode('utf-8') - if isinstance(s, unicode) else s for s in row_unicode) + row_bytes = tuple( + s.encode('utf-8') if isinstance(s, unicode) else s + for s in row_unicode) data = [row_bytes] * 2 self.assertEqual(self.get_back(), data) @@ -1794,19 +1824,22 @@ def testInserttableUnicodeLatin1(self): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', + row_unicode = ( + 0, 0, long(0), False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) - row_unicode = tuple(s.replace(u'€', u'¥') - if isinstance(s, unicode) else s for s in row_unicode) + row_unicode = tuple( + s.replace(u'€', u'¥') if isinstance(s, unicode) else s + for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) if not unicode_strings: - row_bytes = tuple(s.encode('latin1') - if isinstance(s, unicode) else s for s in row_unicode) + row_bytes = tuple( + s.encode('latin1') if isinstance(s, unicode) else s + for s in row_unicode) data = [row_bytes] * 2 self.assertEqual(self.get_back('latin1'), data) @@ -1819,14 +1852,16 @@ def testInserttableUnicodeLatin9(self): return # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', + row_unicode = ( + 0, 0, long(0), False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) if not unicode_strings: - row_bytes = tuple(s.encode('latin9') - if isinstance(s, unicode) else s for s in row_unicode) + row_bytes = tuple( + s.encode('latin9') if isinstance(s, unicode) else s + for s in row_unicode) data = [row_bytes] * 2 self.assertEqual(self.get_back('latin9'), data) @@ -1834,7 +1869,8 @@ def testInserttableNoEncoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, 
u'1970-01-01', u'00:00:00', + row_unicode = ( + 0, 0, long(0), False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] @@ -2106,7 +2142,8 @@ def testSetDecimalPoint(self): en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8' en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar' de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8' - de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25', + de_money = ( + '34,25€', '34,25 €', '€34,25', '€ 34,25', 'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM') # first try with English localization (using the point) for lc in en_locales: @@ -2390,8 +2427,8 @@ def testSetRowFactorySize(self): info = pg._row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual(info.hits, - 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 5218b98d..d9735437 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -205,11 +205,11 @@ class TestParseArray(unittest.TestCase): ('{{{17,18,19},{14,15,16},{11,12,13}},' '{{27,28,29},{24,25,26},{21,22,23}},' '{{37,38,39},{34,35,36},{31,32,33}}}', int, - [[[17, 18, 19], [14, 15, 16], [11, 12, 13]], - [[27, 28, 29], [24, 25, 26], [21, 22, 23]], - [[37, 38, 39], [34, 35, 36], [31, 32, 33]]]), + [[[17, 18, 19], [14, 15, 16], [11, 12, 13]], + [[27, 28, 29], [24, 25, 26], [21, 22, 23]], + [[37, 38, 39], [34, 35, 36], [31, 32, 33]]]), ('{{"breakfast", "consulting"}, {"meeting", "lunch"}}', str, - [['breakfast', 'consulting'], ['meeting', 'lunch']]), + [['breakfast', 'consulting'], ['meeting', 'lunch']]), ('[1:3]={1,2,3}', int, [1, 2, 3]), ('[-1:1]={1,2,3}', int, [1, 2, 3]), ('[-1:+1]={1,2,3}', int, [1, 2, 3]), @@ -221,9 +221,9 @@ class TestParseArray(unittest.TestCase): ('[1:]={1,2,3}', int, ValueError), ('[:3]={1,2,3}', int, ValueError), ('[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}', - int, [[[1, 2, 3], [4, 5, 6]]]), + int, [[[1, 2, 3], [4, 5, 6]]]), (' [1:1] [-2:-1] [3:5] = { { { 1 , 2 , 3 }, {4 , 5 , 6 } } }', - int, [[[1, 2, 3], [4, 5, 6]]]), + int, [[[1, 2, 3], [4, 5, 6]]]), ('[1:1][3:5]={{1,2,3},{4,5,6}}', int, [[1, 2, 3], [4, 5, 6]]), ('[3:5]={{1,2,3},{4,5,6}}', int, ValueError), ('[1:1][-2:-1][3:5]={{1,2,3},{4,5,6}}', int, ValueError)] @@ -309,7 +309,9 @@ def testParserCast(self): self.assertEqual(f('{a}', None), ['a']) self.assertRaises(ValueError, f, '{a}', int) self.assertEqual(f('{a}', str), ['a']) - cast = lambda s: '%s is ok' % s + + def cast(s): + return '%s is ok' % s self.assertEqual(f('{a}', cast), ['a is ok']) def testParserDelim(self): @@ -549,7 +551,9 @@ def testParserCastUniform(self): self.assertEqual(f('(a)', None), ('a',)) self.assertRaises(ValueError, f, '(a)', int) self.assertEqual(f('(a)', str), ('a',)) - cast = lambda s: '%s is ok' % s + + def cast(s): + return '%s is ok' % s self.assertEqual(f('(a)', cast), ('a is ok',)) def testParserCastNonUniform(self): @@ -568,19 +572,25 @@ def testParserCastNonUniform(self): self.assertRaises(ValueError, f, '(1,a)', [str, int]) self.assertEqual(f('(a,1)', [str, int]), ('a', 1)) self.assertRaises(ValueError, f, '(a,1)', [int, str]) - self.assertEqual(f('(1,a,2,b,3,c)', - [int, str, int, str, int, str]), (1, 'a', 2, 'b', 3, 'c')) - 
self.assertEqual(f('(1,a,2,b,3,c)', - (int, str, int, str, int, str)), (1, 'a', 2, 'b', 3, 'c')) - cast1 = lambda s: '%s is ok' % s + self.assertEqual( + f('(1,a,2,b,3,c)', [int, str, int, str, int, str]), + (1, 'a', 2, 'b', 3, 'c')) + self.assertEqual( + f('(1,a,2,b,3,c)', (int, str, int, str, int, str)), + (1, 'a', 2, 'b', 3, 'c')) + + def cast1(s): + return '%s is ok' % s self.assertEqual(f('(a)', [cast1]), ('a is ok',)) - cast2 = lambda s: 'and %s is ok, too' % s - self.assertEqual(f('(a,b)', [cast1, cast2]), - ('a is ok', 'and b is ok, too')) + + def cast2(s): + return 'and %s is ok, too' % s + self.assertEqual( + f('(a,b)', [cast1, cast2]), ('a is ok', 'and b is ok, too')) self.assertRaises(ValueError, f, '(a)', [cast1, cast2]) self.assertRaises(ValueError, f, '(a,b,c)', [cast1, cast2]) - self.assertEqual(f('(1,2,3,4,5,6)', - [int, float, str, None, cast1, cast2]), + self.assertEqual( + f('(1,2,3,4,5,6)', [int, float, str, None, cast1, cast2]), (1, 2.0, '3', '4', '5 is ok', 'and 6 is ok, too')) def testParserDelim(self): @@ -656,10 +666,9 @@ class TestParseHStore(unittest.TestCase): ('"k=>v', ValueError), ('k=>"v', ValueError), ('"1-a" => "anything at all"', {'1-a': 'anything at all'}), - ('k => v, foo => bar, baz => whatever,' - ' "1-a" => "anything at all"', - {'k': 'v', 'foo': 'bar', 'baz': 'whatever', - '1-a': 'anything at all'}), + ('k => v, foo => bar, baz => whatever, "1-a" => "anything at all"', + {'k': 'v', 'foo': 'bar', 'baz': 'whatever', + '1-a': 'anything at all'}), ('"Hello, World!"=>"Hi!"', {'Hello, World!': 'Hi!'}), ('"Hi!"=>"Hello, World!"', {'Hi!': 'Hello, World!'}), (r'"k=>v"=>k\=\>v', {'k=>v': 'k=>v'}), @@ -691,10 +700,10 @@ class TestCastInterval(unittest.TestCase): ('-1:00:00', '-01:00:00', '@ -1 hour', 'PT-1H')), ((0, 0, 0, 1, 0, 0, 0), ('0-0 0 1:00:00', '0 years 0 mons 0 days 01:00:00', - '@ 0 years 0 mons 0 days 1 hour', 'P0Y0M0DT1H')), + '@ 0 years 0 mons 0 days 1 hour', 'P0Y0M0DT1H')), ((0, 0, 0, -1, 0, 0, 0), ('-0-0 -1:00:00', '0 years 0 mons 0 days -01:00:00', - '@ 0 years 0 mons 0 days -1 hour', 'P0Y0M0DT-1H')), + '@ 0 years 0 mons 0 days -1 hour', 'P0Y0M0DT-1H')), ((0, 0, 1, 0, 0, 0, 0), ('1 0:00:00', '1 day', '@ 1 day', 'P1D')), ((0, 0, -1, 0, 0, 0, 0), @@ -848,7 +857,8 @@ def testCastInterval(self): f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result days += 365 * years + 30 * mons - interval = timedelta(days=days, hours=hours, minutes=mins, + interval = timedelta( + days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) for value in values: self.assertEqual(f(value), interval) diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 65552b5c..0206e4ca 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -191,7 +191,8 @@ def testNotifyWrongEvent(self): try: handler() except pg.DatabaseError as error: - self.assertEqual(str(error), + self.assertEqual( + str(error), 'Listening for "good_event" and "stop_good_event",' ' but notified of "bad_event"') self.assertIsNotNone(self.timeout) diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 51d296ef..f86c2bee 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -126,8 +126,8 @@ class TestCopy(unittest.TestCase): @staticmethod def connect(): - return pgdb.connect(database=dbname, - host='%s:%d' % (dbhost or '', dbport or -1)) + return pgdb.connect( + database=dbname, host='%s:%d' % (dbhost or '', dbport or -1)) @classmethod def setUpClass(cls): @@ -244,8 
+244,8 @@ def test_bad_params(self): self.assertRaises(TypeError, call, '0\t', 'copytest', null=42) self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad') self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42) - self.assertRaises(ValueError, call, b'', 'copytest', - format='binary', sep=',') + self.assertRaises( + ValueError, call, b'', 'copytest', format='binary', sep=',') def test_input_string(self): ret = self.copy_from('42\tHello, world!') @@ -353,13 +353,13 @@ def test_csv_with_sep(self): self.check_rowcount() def test_binary(self): - self.assertRaises(IOError, self.copy_from, - b'NOPGCOPY\n', format='binary') + self.assertRaises( + IOError, self.copy_from, b'NOPGCOPY\n', format='binary') self.check_rowcount(-1) def test_binary_with_sep(self): - self.assertRaises(ValueError, self.copy_from, - '', format='binary', sep='\t') + self.assertRaises( + ValueError, self.copy_from, '', format='binary', sep='\t') def test_binary_with_unicode(self): self.assertRaises(ValueError, self.copy_from, u'', format='binary') @@ -395,8 +395,8 @@ def test_size_negative(self): self.check_rowcount() def test_size_invalid(self): - self.assertRaises(TypeError, - self.copy_from, self.data_file, size='invalid') + self.assertRaises( + TypeError, self.copy_from, self.data_file, size='invalid') class TestCopyTo(TestCopy): @@ -545,8 +545,8 @@ def test_binary_with_sep(self): self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t') def test_binary_with_unicode(self): - self.assertRaises(ValueError, self.copy_to, - format='binary', decode=True) + self.assertRaises( + ValueError, self.copy_to, format='binary', decode=True) def test_query(self): self.assertRaises( From 0874176176edc595b57fccbeaa34a626256c116e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 May 2020 22:23:30 +0200 Subject: [PATCH 014/194] Document escape_bytea properly --- docs/contents/changelog.rst | 8 ++++---- docs/contents/pg/db_wrapper.rst | 6 +++--- docs/contents/pg/module.rst | 7 ++++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 7e2731c1..3cceabf1 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,10 +5,10 @@ Version 5.2 (to be released) ---------------------------- - We now require Python version 2.7 or 3.5 and newer - Changes to the classic PyGreSQL module (pg): - - New module level function `get_pqlib_version()` that gets the version - of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). - - New query method `memsize()` that gets the memory size allocated by - the query (needs PostgreSQL >= 12 on the client). + - New module level function `get_pqlib_version()` that gets the version + of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). + - New query method `memsize()` that gets the memory size allocated by + the query (needs PostgreSQL >= 12 on the client). Version 5.1.2 (2020-04-19) diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 2727f8fc..540871fc 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -851,9 +851,9 @@ properties (such as character encoding). 
Escape binary data for use within SQL as type ``bytea`` - :param str datastring: string containing the binary data that is to be escaped + :param bytes/str datastring: the binary data that is to be escaped :returns: the escaped string - :rtype: str + :rtype: bytes/str Similar to the module function :func:`pg.escape_bytea` with the same name, but the behavior of this method is adjusted depending on the connection @@ -866,7 +866,7 @@ unescape_bytea -- unescape data retrieved from the database Unescape ``bytea`` data that has been retrieved as text - :param datastring: the ``bytea`` data string that has been retrieved as text + :param str string: the ``bytea`` string that has been retrieved as text :returns: byte string containing the binary data :rtype: bytes diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 1a92f283..4a1cbea1 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -295,12 +295,13 @@ escape_bytea -- escape binary data for use within SQL escape binary data for use within SQL as type ``bytea`` - :param str datastring: string containing the binary data that is to be escaped + :param bytes/str datastring: the binary data that is to be escaped :returns: the escaped string - :rtype: str + :rtype: bytes/str :raises TypeError: bad argument type, or too many arguments Escapes binary data for use within an SQL command with the type ``bytea``. +The return value will have the same type as the given *datastring*. As with :func:`escape_string`, this is only used when inserting data directly into an SQL command string. @@ -320,7 +321,7 @@ unescape_bytea -- unescape data that has been retrieved as text Unescape ``bytea`` data that has been retrieved as text - :param str datastring: the ``bytea`` data string that has been retrieved as text + :param str string: the ``bytea`` string that has been retrieved as text :returns: byte string containing the binary data :rtype: bytes :raises TypeError: bad argument type, or too many arguments From 9dbc6b984ada1053626d0632e4de89ee64be6320 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 9 Jun 2020 16:36:25 +0200 Subject: [PATCH 015/194] Remove www from CNAME --- .github/workflows/release-docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml index 2b77e8db..1f5e042d 100644 --- a/.github/workflows/release-docs.yml +++ b/.github/workflows/release-docs.yml @@ -33,6 +33,6 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_branch: gh-pages publish_dir: docs/_build/html - cname: www.pygresql.org + cname: pygresql.org enable_jekyll: false force_orphan: true From b1e040e989b5b1b75f42c1103562bfe8f09f93c3 Mon Sep 17 00:00:00 2001 From: Tyler Ramer Date: Sat, 13 Jun 2020 17:55:01 -0400 Subject: [PATCH 016/194] Sanitize parsing of kwargs to handle quote and backslash characters (#40) --- pgdb.py | 3 +-- tests/test_dbapi20.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pgdb.py b/pgdb.py index cdfb171e..398c9ab1 100644 --- a/pgdb.py +++ b/pgdb.py @@ -1662,10 +1662,9 @@ def connect(dsn=None, value = str(value) if not value or ' ' in value: value = "'%s'" % (value.replace( - "'", "\\'").replace('\\', '\\\\'),) + '\\', '\\\\').replace("'", "\\'")) dbname.append('%s=%s' % (kw, value)) dbname = ' '.join(dbname) - # open the connection cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) return Connection(cnx) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 
00280eec..f311a8fa 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -80,7 +80,7 @@ def test_version(self): self.assertEqual(pgdb.__version__, v) def test_connect_kwargs(self): - application_name = 'PyGreSQL DB API 2.0 Test' + application_name = 'PyGreSQL DB API 2.0 Test with\' quote and \\\\backslash' self.connect_kw_args['application_name'] = application_name con = self._connect() cur = con.cursor() From a4de3f5a7e9e496c052350c0c851b70a80fafd63 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 14 Jun 2020 00:53:51 +0200 Subject: [PATCH 017/194] Update changelog --- docs/contents/changelog.rst | 4 ++++ tests/test_dbapi20.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 3cceabf1..7e01520a 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -10,6 +10,10 @@ Version 5.2 (to be released) - New query method `memsize()` that gets the memory size allocated by the query (needs PostgreSQL >= 12 on the client). +- Changes to the DB-API 2 module (pgdb): + - Connection arguments containing single quotes caused problems + (reported and fixed by Tyler Ramer and Jamie McAtamney). + Version 5.1.2 (2020-04-19) -------------------------- diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index f311a8fa..a7d0ac36 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -80,7 +80,7 @@ def test_version(self): self.assertEqual(pgdb.__version__, v) def test_connect_kwargs(self): - application_name = 'PyGreSQL DB API 2.0 Test with\' quote and \\\\backslash' + application_name = 'PyGreSQL DB API 2.0 Test' self.connect_kw_args['application_name'] = application_name con = self._connect() cur = con.cursor() @@ -88,6 +88,15 @@ def test_connect_kwargs(self): " where application_name = %s", (application_name,)) self.assertEqual(cur.fetchone(), (application_name,)) + def test_connect_kwargs_with_special_chars(self): + special_name = 'Single \' and double " quote and \\ backslash!' + self.connect_kw_args['application_name'] = special_name + con = self._connect() + cur = con.cursor() + cur.execute("select application_name from pg_stat_activity" + " where application_name = %s", (special_name,)) + self.assertEqual(cur.fetchone(), (special_name,)) + def test_percent_sign(self): con = self._connect() cur = con.cursor() From 3a3b6e5b3a287ad8b5d37e7b776d6d8b053aabfd Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Jun 2020 14:11:34 +0200 Subject: [PATCH 018/194] Update dbapi20.py from 1.5 to 1.15 (#31) --- tests/dbapi20.py | 764 ++++++++++++++++++++++++----------------------- 1 file changed, 398 insertions(+), 366 deletions(-) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 5d77267e..37027183 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,21 +1,35 @@ -#!/usr/bin/python -'''Python DB API 2.0 driver compliance unit test suite. +"""Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. 
-''' +""" -__version__ = '1.5' +__version__ = '1.15.0' import unittest - import time +try: # noinspection PyUnresolvedReferences + _BaseException = StandardError +except Exception: # Python 2 + _BaseException = Exception + +try: # noinspection PyUnresolvedReferences + unicode +except NameError: # Python 3 + unicode = str + + +def str2bytes(sval): + if str is not unicode and isinstance(sval, str): + sval = sval.decode("latin1") + return sval.encode("latin1") # python 3 make unicode into bytes + class DatabaseAPI20Test(unittest.TestCase): ''' Test a database self.driver for DB API 2.0 compatibility. This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. It is + test case to ensure compliance with the DB-API. It is expected that this TestCase may be expanded in the future if ambiguities or edge conditions are discovered. @@ -36,57 +50,62 @@ class mytest(dbapi20.DatabaseAPI20Test): # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20), drink varchar(30))' % table_prefix xddl1 = 'drop table %sbooze' % table_prefix xddl2 = 'drop table %sbarflys' % table_prefix + insert = 'insert' - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def executeDDL1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - """self.drivers should override this method to perform required setup + ''' self.drivers should override this method to perform required setup if any is necessary, such as creating the database. - """ + ''' pass def tearDown(self): - """self.drivers should override this method to perform required cleanup + ''' self.drivers should override this method to perform required cleanup if any is necessary, such as deleting the test database. The default drops the tables that may be created. - """ - con = self._connect() + ''' try: - cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): - try: - cur.execute(ddl) - con.commit() - except self.driver.Error: - # Assume table didn't exist. Other tests will check if - # execute is busted. - pass - finally: - con.close() + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1, self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. 
+ pass + finally: + con.close() + except _BaseException: + pass def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + r = self.driver.connect( + *self.connect_args, **self.connect_kw_args + ) except AttributeError: self.fail("No connect method found in self.driver module") + return r def test_connect(self): con = self._connect() @@ -97,7 +116,7 @@ def test_apilevel(self): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, '2.0') except AttributeError: self.fail("Driver doesn't define apilevel") @@ -106,7 +125,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -116,48 +135,32 @@ def test_paramstyle(self): paramstyle = self.driver.paramstyle # Must be a valid value self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + 'qmark', 'numeric', 'named', 'format', 'pyformat' + )) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): - """Make sure required exceptions exist, and are in the - defined hierarchy. - """ - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) + # Make sure required exceptions exist, and are in the + # defined hierarchy. + self.assertTrue(issubclass(self.driver.Warning, _BaseException)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): - """Optional extension - - Test for the optional DB API 2.0 extension, where the exceptions - are exposed as attributes on the Connection object - I figure this optional extension will be implemented by any - driver author who is using this test suite, so it is enabled - by default. - """ + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. 
con = self._connect() drv = self.driver self.assertTrue(con.Warning is drv.Warning) @@ -182,7 +185,7 @@ def test_rollback(self): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): + if hasattr(con, 'rollback'): try: con.rollback() except self.driver.NotSupportedError: @@ -203,14 +206,14 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) + cur1.execute("%s into %sbooze values ('Victoria Bitter')" % ( + self.insert, self.table_prefix + )) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], 'Victoria Bitter') finally: con.close() @@ -219,31 +222,31 @@ def test_description(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) + self.assertEqual(cur.description, None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual(len(cur.description), 1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]), 7, + 'cursor.description[x] tuples must have 7 elements' + ) + self.assertEqual(cur.description[0][0].lower(), 'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1], self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1] + ) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertEqual(cur.description, None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. 
DDL)' + ) finally: con.close() @@ -252,47 +255,48 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertTrue(cur.rowcount in (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result ' + 'statements' + ) + cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( + self.insert, self.table_prefix + )) + self.assertTrue(cur.rowcount in (-1, 1), + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue(cur.rowcount in (-1, 1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.assertTrue(cur.rowcount in (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result ' + 'statements' + ) finally: con.close() lower_func = 'lower' + def test_callproc(self): con = self._connect() try: cur = con.cursor() - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) - self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, 'callproc'): + r = cur.callproc(self.lower_func, ('FOO',)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], 'FOO') r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, 'callproc produced no result set') + self.assertEqual(len(r[0]), 1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0], 'foo', + 'callproc produced invalid results' + ) finally: con.close() @@ -305,14 +309,18 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' - self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) + def test_non_idempotent_close(self): + con = self._connect() + con.close() # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) + # !!! reasonable persons differ about the usefulness of this test and this feature !!! 
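In isolation, the behaviour the assertion below demands looks like this,
assuming pgdb as the driver under test (the connection parameters are
placeholders):

    import pgdb

    con = pgdb.connect(database='test')  # placeholder parameters
    con.close()      # the first close succeeds
    try:
        con.close()  # a second close must raise Error
    except pgdb.Error:
        print('second close raised pgdb.Error, as this test requires')
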
+ self.assertRaises(self.driver.Error, con.close) def test_execute(self): con = self._connect() @@ -322,105 +330,125 @@ def test_execute(self): finally: con.close() - def _paraminsert(self,cur): - self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix + def _paraminsert(self, cur): + self.executeDDL2(cur) + cur.execute( + "%s into %sbarflys values ('Victoria Bitter', 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix )) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertTrue(cur.rowcount in (-1, 1)) if self.driver.paramstyle == 'qmark': cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, + "%s into %sbarflys values (?, 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), ("Cooper's",) - ) + ) elif self.driver.paramstyle == 'numeric': cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, + "%s into %sbarflys values (:1, 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), ("Cooper's",) - ) + ) elif self.driver.paramstyle == 'named': cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "%s into %sbarflys values (:beer, 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + {'beer': "Cooper's"} + ) elif self.driver.paramstyle == 'format': cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, + "%s into %sbarflys values (%%s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), ("Cooper's",) - ) + ) elif self.driver.paramstyle == 'pyformat': cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "%s into %sbarflys values (%%(beer)s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + {'beer': "Cooper's"} + ) else: self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.assertTrue(cur.rowcount in (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute('select name, drink from %sbarflys' % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual(beers[0], "Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1], "Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + trouble = "thi%s :may ca%(u)se? troub:1e" + self.assertEqual(res[0][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly. Got=%s, Expected=%s' % ( + repr(res[0][1]), repr(trouble))) + self.assertEqual(res[1][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly. 
Got=%s, Expected=%s' % ( + repr(res[1][1]), repr(trouble) + )) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + largs = [("Cooper's",), ("Boag's",)] + margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] if self.driver.paramstyle == 'qmark': cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, + '%s into %sbooze values (?)' % ( + self.insert, self.table_prefix), largs - ) + ) elif self.driver.paramstyle == 'numeric': cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, + '%s into %sbooze values (:1)' % ( + self.insert, self.table_prefix), largs - ) + ) elif self.driver.paramstyle == 'named': cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, + '%s into %sbooze values (:beer)' % ( + self.insert, self.table_prefix), margs - ) + ) elif self.driver.paramstyle == 'format': cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, + '%s into %sbooze values (%%s)' % ( + self.insert, self.table_prefix), largs - ) + ) elif self.driver.paramstyle == 'pyformat': cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), + '%s into %sbooze values (%%(beer)s)' % ( + self.insert, self.table_prefix + ), margs - ) + ) else: self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount - ) + self.assertTrue(cur.rowcount in (-1, 2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) cur.execute('select name from %sbooze' % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved') + self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved') finally: con.close() @@ -431,39 +459,39 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + self.assertEqual(cur.fetchone(), None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( + self.insert, self.table_prefix + )) + 
self.assertRaises(self.driver.Error, cur.fetchone) cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual(len(r), 1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0], 'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) + self.assertEqual(cur.fetchone(), None, + 'cursor.fetchone should return None if no more rows available' + ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() @@ -474,16 +502,17 @@ def test_fetchone(self): 'Redback', 'Victoria Bitter', 'XXXX' - ] + ] def _populate(self): - """Return a list of sql commands to setup the DB for the fetch + ''' Return a list of sql commands to setup the DB for the fetch tests. - """ + ''' populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + "%s into %sbooze values ('%s')" % ( + self.insert, self.table_prefix, s) + for s in self.samples + ] return populate def test_fetchmany(self): @@ -492,8 +521,8 @@ def test_fetchmany(self): cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): @@ -501,69 +530,69 @@ def test_fetchmany(self): cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' - ) - self.assertTrue(cur.rowcount in (-1,6)) + self.assertEqual(len(r), 1, + 'cursor.fetchmany retrieved incorrect number of rows, ' + 'default of arraysize is one.' 
+ ) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r), 3, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r), 2, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual(len(r), 0, + 'cursor.fetchmany should return an empty sequence after ' + 'results are exhausted' + ) + self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 + cur.arraysize = 4 cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r), 4, + 'cursor.arraysize not being honoured by fetchmany' + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + self.assertTrue(cur.rowcount in (-1, 6)) + + cur.arraysize = 6 cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,6)) + for i in range(0, 6): + self.assertEqual(rows[i], self.samples[i], + 'incorrect data retrieved by cursor.fetchmany' + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual(len(rows), 0, + 'cursor.fetchmany should return an empty sequence if ' + 'called after the whole result set has been fetched' + ) + self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(r), 0, + 'cursor.fetchmany should return an empty sequence if ' + 'query retrieved no rows' + ) + self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() @@ -583,36 +612,36 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) cur.execute('select name from %sbooze' % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + 
self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual(len(rows), len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' - ) + for i in range(0, len(self.samples)): + self.assertEqual(rows[i], self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) rows = cur.fetchall() self.assertEqual( - len(rows),0, + len(rows), 0, 'cursor.fetchall should return an empty list if called ' 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual(len(rows), 0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) finally: con.close() @@ -626,92 +655,90 @@ def test_mixedfetch(self): cur.execute(sql) cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows23), 2, + 'fetchmany returned incorrect number of rows' + ) + self.assertEqual(len(rows56), 2, + 'fetchall returned incorrect number of rows' + ) rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + for i in range(0, len(self.samples)): + self.assertEqual(rows[i], self.samples[i], + 'incorrect data retrieved or inserted' + ) finally: con.close() def help_nextset_setUp(self, cur): - """Should create a procedure called deleteme + ''' Should create a procedure called deleteme that returns two result sets, first the - number of rows in booze then "name from booze" - """ - if False: - sql = """ - create procedure deleteme as - begin - select count(*) from booze - select name from booze - end - """ - cur.execute(sql) - else: - raise NotImplementedError('Helper not implemented') + number of rows in booze then "name from booze" + ''' + raise NotImplementedError('Helper not implemented') + # sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + # """ + # cur.execute(sql) def help_nextset_tearDown(self, cur): - """If cleaning up is needed after nextSetTest""" - if False: - cur.execute("drop procedure deleteme") - else: - - raise NotImplementedError('Helper not implemented') + 'If cleaning up is needed after nextSetTest' + raise NotImplementedError('Helper not implemented') + # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() - if not 
hasattr(cur,'nextset'): + if not hasattr(cur, 'nextset'): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' + s = cur.nextset() + assert s == None, 'No more return sets, should return None' finally: self.help_nextset_tearDown(cur) finally: con.close() + def test_nextset(self): + raise NotImplementedError('Drivers need to override this test') + def test_arraysize(self): - """Not much here - rest of the tests for this are in test_fetchmany""" + # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue(hasattr(cur, 'arraysize'), + 'cursor.arraysize must be defined' + ) finally: con.close() @@ -719,86 +746,91 @@ def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() def test_setoutputsize_basic(self): - """Basic test is to make sure setoutputsize doesn't blow up""" + # Basic test is to make sure setoutputsize doesn't blow up con = self._connect() try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): - """Real test for setoutputsize is driver dependant""" - raise NotImplementedError('Driver needs to override this test') + # Real test for setoutputsize is driver dependant + raise NotImplementedError('Driver needed to override this test') def test_None(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + self.executeDDL2(cur) + # inserting NULL to the second column, because some drivers might + # need the first one to be primary key, which means it needs + # to have a non-NULL value + cur.execute("%s into %sbarflys values ('a', NULL)" % ( + self.insert, self.table_prefix)) + cur.execute('select drink from %sbarflys' % self.table_prefix) r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, 'NULL value not returned as None') finally: con.close() def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + d1 = self.driver.Date(2002, 12, 25) + d2 = self.driver.DateFromTicks( + time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? 
API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + t1 = self.driver.Time(13, 45, 30) + t2 = self.driver.TimeFromTicks( + time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') + b = self.driver.Binary(str2bytes('Something')) + b = self.driver.Binary(str2bytes('')) def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + self.assertTrue(hasattr(self.driver, 'STRING'), + 'module.STRING must be defined' + ) def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + self.assertTrue(hasattr(self.driver, 'BINARY'), + 'module.BINARY must be defined.' + ) def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + self.assertTrue(hasattr(self.driver, 'NUMBER'), + 'module.NUMBER must be defined.' + ) def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + self.assertTrue(hasattr(self.driver, 'DATETIME'), + 'module.DATETIME must be defined.' + ) def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) - + self.assertTrue(hasattr(self.driver, 'ROWID'), + 'module.ROWID must be defined.' + ) From 7b95fc7d3f18365872a907640076f5a1f2147979 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Jun 2020 19:08:41 +0200 Subject: [PATCH 019/194] Make all Python code PEP8 compliant using flake8 Also make all files pass PyCharm inspections. Fix minor issues that were detected by flake8 and PyCharm. --- .flake8 | 4 + docs/contents/changelog.rst | 6 +- pg.py | 110 +++--- pgdb.py | 100 ++++-- pgmodule.c | 4 +- py3c.h | 2 + tests/dbapi20.py | 537 ++++++++++++++--------------- tests/test_classic.py | 129 +++---- tests/test_classic_connection.py | 98 ++++-- tests/test_classic_dbwrapper.py | 349 +++++++++++-------- tests/test_classic_functions.py | 8 +- tests/test_classic_largeobj.py | 6 +- tests/test_classic_notification.py | 11 +- tests/test_dbapi20.py | 124 ++++--- tests/test_dbapi20_copy.py | 14 +- tests/test_tutorial.py | 4 +- tox.ini | 8 +- 17 files changed, 820 insertions(+), 694 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..3f6e0a3c --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +ignore = F403,F405,W503 +exclude = .git,.tox,.venv,build,dist,docs +max-line-length = 79 diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 7e01520a..af038add 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,7 +3,8 @@ ChangeLog Version 5.2 (to be released) ---------------------------- -- We now require Python version 2.7 or 3.5 and newer +- We now require Python version 2.7 or 3.5 and newer. 
+- All Python code is now tested with flake8 and made PEP8 compliant. - Changes to the classic PyGreSQL module (pg): - New module level function `get_pqlib_version()` that gets the version of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). @@ -11,10 +12,11 @@ Version 5.2 (to be released) the query (needs PostgreSQL >= 12 on the client). - Changes to the DB-API 2 module (pgdb): + - When using Python 2, errors are now derived from StandardError + instead of Exception, as required by the DB-API 2 compliance test. - Connection arguments containing single quotes caused problems (reported and fixed by Tyler Ramer and Jamie McAtamney). - Version 5.1.2 (2020-04-19) -------------------------- - Improved handling of build_ext options for disabling certain features. diff --git a/pg.py b/pg.py index 33a3ae75..94cfa653 100644 --- a/pg.py +++ b/pg.py @@ -82,12 +82,12 @@ from json import loads as jsondecode, dumps as jsonencode from uuid import UUID -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable long except NameError: # Python >= 3.0 long = int -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable basestring except NameError: # Python >= 3.0 basestring = (str, bytes) @@ -96,13 +96,15 @@ from functools import lru_cache except ImportError: # Python < 3.2 from functools import update_wrapper - try: + try: # noinspection PyCompatibility from _thread import RLock except ImportError: class RLock: # for builds without threads - def __enter__(self): pass + def __enter__(self): + pass - def __exit__(self, exctype, excinst, exctb): pass + def __exit__(self, exctype, excinst, exctb): + pass def lru_cache(maxsize=128): """Simplified functools.lru_cache decorator for one argument.""" @@ -139,9 +141,9 @@ def wrapper(arg): link = get(arg) if link is not None: root = root_full[0] - prev, next, _arg, res = link - prev[1] = next - next[0] = prev + prv, nxt, _arg, res = link + prv[1] = nxt + nxt[0] = prv last = root[0] last[1] = root[0] = link link[0] = last @@ -158,7 +160,7 @@ def wrapper(arg): oldroot[3] = res root = root_full[0] = oldroot[1] oldarg = root[2] - oldres = root[3] # keep reference + oldres = root[3] # noqa F481 (keep reference) root[2] = root[3] = None del cache[oldarg] cache[arg] = oldroot @@ -178,7 +180,7 @@ def wrapper(arg): # Auxiliary classes and functions that are independent from a DB connection: -try: +try: # noinspection PyUnresolvedReferences from inspect import signature except ImportError: # Python < 3.3 from inspect import getargspec @@ -254,16 +256,18 @@ def _oid_key(table): class _SimpleTypes(dict): """Dictionary mapping pg_type names to simple type names.""" - _types = {'bool': 'bool', + _types = { + 'bool': 'bool', 'bytea': 'bytea', 'date': 'date interval time timetz timestamp timestamptz' - ' abstime reltime', # these are very old + ' abstime reltime', # these are very old 'float': 'float4 float8', 'int': 'cid int2 int4 int8 oid xid', 'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid', 'num': 'numeric', 'money': 'money', 'text': 'bpchar char name text varchar'} + # noinspection PyMissingConstructor def __init__(self): for typ, keys in self._types.items(): for key in keys.split(): @@ -274,6 +278,7 @@ def __init__(self): def __missing__(key): return 'text' + _simpletypes = _SimpleTypes() @@ -299,6 +304,7 @@ def add(self, value, typ=None): If this is a literal value, it will be returned as is. 
Otherwise, a placeholder will be returned and the parameter list will be augmented. """ + # noinspection PyUnresolvedReferences value = self.adapt(value, typ) if isinstance(value, Literal): return value @@ -465,7 +471,7 @@ def _adapt_num_array(cls, v): return str(v) _adapt_int_array = _adapt_float_array = _adapt_money_array = \ - _adapt_num_array + _adapt_num_array def _adapt_bytea_array(self, v): """Adapt a bytea array parameter.""" @@ -548,6 +554,7 @@ def simple_type(name): def get_simple_name(typ): """Get the simple name of a database type.""" if isinstance(typ, DbType): + # noinspection PyUnresolvedReferences return typ.simple return _simpletypes[typ] @@ -601,9 +608,10 @@ def guess_simple_type(cls, value): simple_type = cls.simple_type guess = cls.guess_simple_type + # noinspection PyUnusedLocal def get_attnames(self): return AttrDict((str(n + 1), simple_type(guess(v))) - for n, v in enumerate(value)) + for n, v in enumerate(value)) typ = simple_type('record') typ._get_attnames = get_attnames @@ -631,7 +639,9 @@ def adapt_inline(self, value, nested=False): if bytes is not str: # Python >= 3.0 value = value.decode('ascii') elif isinstance(value, Json): + # noinspection PyUnresolvedReferences if value.encode: + # noinspection PyUnresolvedReferences return value.encode() value = self.db.encode_json(value) elif isinstance(value, (datetime, date, time, timedelta)): @@ -689,8 +699,8 @@ def format_query(self, command, values=None, types=None, inline=False): else: add = params.add if types: - if (not isinstance(types, (list, tuple)) or - len(types) != len(values)): + if (not isinstance(types, (list, tuple)) + or len(types) != len(values)): raise TypeError('The values and types do not match') literals = [add(value, typ) for value, typ in zip(values, types)] @@ -850,7 +860,7 @@ def cast_timestamptz(value, connection): if len(value[3]) > 4: return datetime.max fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] value, tz = value[:-1], value[-1] else: if fmt.startswith('%Y-'): @@ -956,7 +966,7 @@ def cast_interval(value): raise ValueError('Cannot parse interval: %s' % value) days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) + seconds=secs, microseconds=usecs) class Typecasts(dict): @@ -973,7 +983,8 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = {'char': str, 'bpchar': str, 'name': str, + defaults = { + 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, @@ -1084,6 +1095,7 @@ def set_default(cls, typ, cast): defaults[t] = cast defaults.pop('_%s' % t, None) + # noinspection PyMethodMayBeStatic,PyUnusedLocal def get_attnames(self, typ): """Return the fields for the given record type. @@ -1091,6 +1103,7 @@ def get_attnames(self, typ): """ return {} + # noinspection PyMethodMayBeStatic def dateformat(self): """Return the current date format. 
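The adaptation and format_query() changes in this file back DB.query_formatted(),
which takes format/pyformat style placeholders and adapts the Python values for
PostgreSQL. A short sketch of the call, with placeholder connection parameters
and an illustrative table:

    from pg import DB

    db = DB(dbname='test')  # placeholder parameters
    q = db.query_formatted(
        'select * from weather where city = %s and temp_lo > %s',
        ('San Francisco', 42))
    print(q.getresult())
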
@@ -1112,6 +1125,7 @@ def create_record_cast(self, name, fields, casts): record = namedtuple(name, fields) def cast(v): + # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast @@ -1149,6 +1163,7 @@ class DbType(str): @property def attnames(self): """Get names and types of the fields of a composite type.""" + # noinspection PyUnresolvedReferences return self._get_attnames(self) @@ -1185,7 +1200,7 @@ def __init__(self, db): " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") def add(self, oid, pgtype, regtype, - typtype, category, delim, relid): + typtype, category, delim, relid): """Create a PostgreSQL type name with additional info.""" if oid in self: return self[oid] @@ -1268,6 +1283,7 @@ def typecast(self, value, typ): # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. +# noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) def _row_factory(names): """Get a namedtuple factory for row results with the given names.""" @@ -1283,6 +1299,7 @@ def set_row_factory_size(maxsize): If maxsize is set to None, the cache can grow without bound. """ + # noinspection PyGlobalUndefined global _row_factory _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) @@ -1364,7 +1381,7 @@ class NotificationHandler(object): """A PostgreSQL client-side asynchronous notification handler.""" def __init__(self, db, event, callback=None, - arg_dict=None, timeout=None, stop_event=None): + arg_dict=None, timeout=None, stop_event=None): """Initialize the notification handler. You must pass a PyGreSQL database connection, the name of an @@ -1459,6 +1476,7 @@ def __call__(self): if not poll: rlist = [self.db.fileno()] while self.listening: + # noinspection PyUnboundLocalVariable if poll or select.select(rlist, [], [], self.timeout)[0]: while self.listening: notice = self.db.getnotify() @@ -1484,7 +1502,7 @@ def __call__(self): def pgnotify(*args, **kw): """Same as NotificationHandler, under the traditional name.""" warnings.warn("pgnotify is deprecated, use NotificationHandler instead", - DeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) return NotificationHandler(*args, **kw) @@ -1513,6 +1531,7 @@ def __init__(self, *args, **kw): db = db.db else: try: + # noinspection PyUnresolvedReferences db = db._cnx except AttributeError: pass @@ -1551,11 +1570,12 @@ def __init__(self, *args, **kw): " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s" " AND NOT a.attisdropped ORDER BY a.attnum") db.set_cast_hook(self.dbtypes.typecast) - self.debug = None # For debugging scripts, this can be set - # * to a string format specification (e.g. in CGI set to "%s
"), - # * to a file object to write debug statements or - # * to a callable object which takes a string argument - # * to any other true value to just print debug statements + # For debugging scripts, self.debug can be set + # * to a string format specification (e.g. in CGI set to "%s
"), + # * to a file object to write debug statements or + # * to a callable object which takes a string argument + # * to any other true value to just print debug statements + self.debug = None def __getattr__(self, name): # All undefined members are same as in underlying connection: @@ -1610,6 +1630,7 @@ def _do_debug(self, *args): if isinstance(self.debug, basestring): print(self.debug % s) elif hasattr(self.debug, 'write'): + # noinspection PyCallingNonCallable self.debug.write(s + '\n') elif callable(self.debug): self.debug(s) @@ -1633,7 +1654,8 @@ def _make_bool(d): """Get boolean value corresponding to d.""" return bool(d) if get_bool() else ('t' if d else 'f') - def _list_params(self, params): + @staticmethod + def _list_params(params): """Create a human readable parameter list.""" return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1)) @@ -1643,11 +1665,13 @@ def _list_params(self, params): # so we define unescape_bytea as a method as well unescape_bytea = staticmethod(unescape_bytea) - def decode_json(self, s): + @staticmethod + def decode_json(s): """Decode a JSON string coming from the database.""" return (get_jsondecode() or jsondecode)(s) - def encode_json(self, d): + @staticmethod + def encode_json(d): """Encode a JSON string for use within SQL.""" return jsonencode(d) @@ -1991,7 +2015,7 @@ def pkey(self, table, composite=False, flush=False): " AND NOT a.attisdropped" " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass" " AND i.indisprimary ORDER BY a.attnum") % ( - _quote_if_unqualified('$1', table),) + _quote_if_unqualified('$1', table),) pkey = self.db.query(q, (table,)).getresult() if not pkey: raise KeyError('Table %s has no primary key' % table) @@ -2319,10 +2343,11 @@ def upsert(self, table, row=None, **kw): appear as keys in the dictionary are also updated like in the case keywords had been passed with the value True. - So if in the case of a conflict you want to update every column that - has been passed in the dictionary row, you would call upsert(table, row). - If you don't want to do anything in case of a conflict, i.e. leave - the existing row as it is, call upsert(table, row, **dict.fromkeys(row)). + So if in the case of a conflict you want to update every column + that has been passed in the dictionary row, you would call + upsert(table, row). If you don't want to do anything in case + of a conflict, i.e. leave the existing row as it is, call + upsert(table, row, **dict.fromkeys(row)). If you need more fine-grained control of what gets updated, you can also pass strings in the keyword parameters. 
These strings will @@ -2351,7 +2376,7 @@ def upsert(self, table, row=None, **kw): params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - names, values, updates = [], [], [] + names, values = [], [] for n in attnames: if n in row: names.append(col(n)) @@ -2378,8 +2403,7 @@ def upsert(self, table, row=None, **kw): ret = 'oid, *' if qoid else '*' q = ('INSERT INTO %s AS included (%s) VALUES (%s)' ' ON CONFLICT (%s) DO %s RETURNING %s') % ( - self._escape_qualified_name(table), names, values, - target, do, ret) + self._escape_qualified_name(table), names, values, target, do, ret) self._do_debug(q, params) try: q = self.db.query(q, params) @@ -2673,7 +2697,10 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, getrow = itemgetter(*rowind) else: rowind = rowind[0] - getrow = lambda row: (row[rowind],) + + def getrow(row): + return row[rowind], # tuple with one item + rowtuple = True rows = map(getrow, res) if keytuple or rowtuple: @@ -2682,13 +2709,14 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, if rowtuple: fields = [f for f in fields if f not in keyset] rows = _namediter(_MemoryQuery(rows, fields)) + # noinspection PyArgumentList return cls(zip(keys, rows)) def notification_handler(self, event, callback, arg_dict=None, timeout=None, stop_event=None): """Get notification handler that will run the given callback.""" - return NotificationHandler(self, - event, callback, arg_dict, timeout, stop_event) + return NotificationHandler(self, event, callback, + arg_dict, timeout, stop_event) # if run as script, print some information diff --git a/pgdb.py b/pgdb.py index 398c9ab1..9016c3cb 100644 --- a/pgdb.py +++ b/pgdb.py @@ -104,10 +104,10 @@ from datetime import date, time, datetime, timedelta, tzinfo from time import localtime -from decimal import Decimal +from decimal import Decimal as StdDecimal from uuid import UUID as Uuid from math import isnan, isinf -try: +try: # noinspection PyCompatibility from collections.abc import Iterable except ImportError: # Python < 3.3 from collections import Iterable @@ -116,17 +116,19 @@ from re import compile as regex from json import loads as jsondecode, dumps as jsonencode -try: # noinspection PyUnresolvedReferences +Decimal = StdDecimal + +try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable long except NameError: # Python >= 3.0 long = int -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable unicode except NameError: # Python >= 3.0 unicode = str -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable basestring except NameError: # Python >= 3.0 basestring = (str, bytes) @@ -135,13 +137,15 @@ from functools import lru_cache except ImportError: # Python < 3.2 from functools import update_wrapper - try: + try: # noinspection PyCompatibility from _thread import RLock except ImportError: class RLock: # for builds without threads - def __enter__(self): pass + def __enter__(self): + pass - def __exit__(self, exctype, excinst, exctb): pass + def __exit__(self, exctype, excinst, exctb): + pass def lru_cache(maxsize=128): """Simplified functools.lru_cache decorator for one argument.""" @@ -178,9 +182,9 @@ def wrapper(arg): link = get(arg) if link is not None: root = root_full[0] - prev, next, _arg, res = link - prev[1] = next - next[0] = prev + prv, nxt, _arg, res = link + prv[1] = nxt + nxt[0] = prv last = root[0] last[1] = root[0] = link link[0] = last @@ -197,7 +201,7 
@@ def wrapper(arg): oldroot[3] = res root = root_full[0] = oldroot[1] oldarg = root[2] - oldres = root[3] # keep reference + oldres = root[3] # noqa F481 (keep reference) root[2] = root[3] = None del cache[oldarg] cache[arg] = oldroot @@ -215,7 +219,7 @@ def wrapper(arg): return decorator -### Module Constants +# *** Module Constants *** # compliant with DB API 2.0 apilevel = '2.0' @@ -231,9 +235,9 @@ def wrapper(arg): shortcutmethods = 1 -### Internal Type Handling +# *** Internal Type Handling *** -try: +try: # noinspection PyUnresolvedReferences from inspect import signature except ImportError: # Python < 3.3 from inspect import getargspec @@ -392,7 +396,7 @@ def cast_timestamp(value, connection): if len(value[3]) > 4: return datetime.max fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] else: if len(value[0]) > 10: return datetime.max @@ -415,7 +419,7 @@ def cast_timestamptz(value, connection): if len(value[3]) > 4: return datetime.max fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] value, tz = value[:-1], value[-1] else: if fmt.startswith('%Y-'): @@ -535,7 +539,8 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = {'char': str, 'bpchar': str, 'name': str, + defaults = { + 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, @@ -647,6 +652,7 @@ def create_record_cast(self, name, fields, casts): record = namedtuple(name, fields) def cast(v): + # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast @@ -707,6 +713,7 @@ def __missing__(self, typ): self[typ] = cast return cast + # noinspection PyMethodMayBeStatic,PyUnusedLocal def get_fields(self, typ): """Return the fields for the given record type. @@ -723,6 +730,7 @@ class TypeCode(str): but carry some additional information. """ + # noinspection PyShadowingBuiltins @classmethod def create(cls, oid, name, len, type, category, delim, relid): """Create a type code for a PostgreSQL data type.""" @@ -735,7 +743,8 @@ def create(cls, oid, name, len, type, category, delim, relid): self.relid = relid return self -FieldInfo = namedtuple('FieldInfo', ['name', 'type']) + +FieldInfo = namedtuple('FieldInfo', ('name', 'type')) class TypeCache(dict): @@ -785,6 +794,7 @@ def __missing__(self, key): type_code = TypeCode.create( int(res[0]), res[1], int(res[2]), res[3], res[4], res[5], int(res[6])) + # noinspection PyUnresolvedReferences self[type_code.oid] = self[str(type_code)] = type_code return type_code @@ -843,10 +853,12 @@ class _quotedict(dict): """ def __getitem__(self, key): + # noinspection PyUnresolvedReferences return self.quote(super(_quotedict, self).__getitem__(key)) -### Error Messages +# *** Error Messages *** + def _db_error(msg, cls=DatabaseError): """Return DatabaseError with empty sqlstate attribute.""" @@ -860,7 +872,8 @@ def _op_error(msg): return _db_error(msg, OperationalError) -### Row Tuples +# *** Row Tuples *** + _re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') @@ -869,6 +882,7 @@ def _op_error(msg): # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. 
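The cache mentioned here can be resized through the module-level
set_row_factory_size() defined right below; the size in this sketch is
arbitrary:

    import pgdb

    # useful when an application fetches rows with very many distinct
    # column layouts; None makes the cache grow without bound
    pgdb.set_row_factory_size(4096)
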
+# noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) def _row_factory(names): """Get a namedtuple factory for row results with the given names.""" @@ -884,11 +898,12 @@ def set_row_factory_size(maxsize): If maxsize is set to None, the cache can grow without bound. """ + # noinspection PyGlobalUndefined global _row_factory _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) -### Cursor Object +# *** Cursor Object *** class Cursor(object): """Cursor object.""" @@ -982,7 +997,7 @@ def _quote(self, value): return '(%s)' % (','.join(str(q(v)) for v in value),) except UnicodeEncodeError: # Python 2 with non-ascii values return u'(%s)' % (','.join(unicode(q(v)) for v in value),) - try: + try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() except AttributeError: raise InterfaceError( @@ -1027,7 +1042,7 @@ def _make_description(self, info): size = mod precision = scale = None return CursorDescription(name, type_code, - None, size, precision, scale, None) + None, size, precision, scale, None) @property def description(self): @@ -1099,6 +1114,7 @@ def executemany(self, operation, seq_of_parameters): except DatabaseError: raise # database provides error message except Error as err: + # noinspection PyTypeChecker raise _db_error( "Error in '%s': '%s' " % (sql, err), InterfaceError) except Exception as err: @@ -1149,7 +1165,7 @@ def fetchmany(self, size=None, keep=False): raise _db_error(str(err)) typecast = self.type_cache.typecast return [self.row_factory([typecast(value, typ) - for typ, value in zip(self.coltypes, row)]) for row in result] + for typ, value in zip(self.coltypes, row)]) for row in result] def callproc(self, procname, parameters=None): """Call a stored database procedure with the given name. @@ -1167,8 +1183,9 @@ def callproc(self, procname, parameters=None): self.execute(query, parameters) return parameters + # noinspection PyShadowingBuiltins def copy_from(self, stream, table, - format=None, sep=None, null=None, size=None, columns=None): + format=None, sep=None, null=None, size=None, columns=None): """Copy data from an input stream to the specified table. The input stream can be a file-like object with a read() method or @@ -1305,8 +1322,9 @@ def chunks(): # return the cursor object, so you can chain operations return self + # noinspection PyShadowingBuiltins def copy_to(self, stream, table, - format=None, sep=None, null=None, decode=None, columns=None): + format=None, sep=None, null=None, decode=None, columns=None): """Copy data from the specified table to an output stream. 
The output stream can be a file-like object with a write() method or @@ -1404,6 +1422,7 @@ def copy(): # write the rows to the file-like input stream for row in copy(): + # noinspection PyUnboundLocalVariable write(row) # return the cursor object, so you can chain operations @@ -1468,12 +1487,12 @@ def build_row_factory(self): return _row_factory(tuple(names)) -CursorDescription = namedtuple('CursorDescription', - ['name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok']) +CursorDescription = namedtuple('CursorDescription', ( + 'name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok')) -### Connection Objects +# *** Connection Objects *** class Connection(object): """Connection object.""" @@ -1604,13 +1623,14 @@ def executemany(self, operation, param_seq): return cursor -### Module Interface +# *** Module Interface *** _connect = connect + def connect(dsn=None, - user=None, password=None, - host=None, database=None, **kwargs): + user=None, password=None, + host=None, database=None, **kwargs): """Connect to a database.""" # first get params from DSN dbport = -1 @@ -1666,11 +1686,12 @@ def connect(dsn=None, dbname.append('%s=%s' % (kw, value)) dbname = ' '.join(dbname) # open the connection + # noinspection PyArgumentList cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) return Connection(cnx) -### Types Handling +# *** Types Handling *** class Type(frozenset): """Type class for a couple of PostgreSQL data types. @@ -1722,6 +1743,7 @@ class RecordType: def __eq__(self, other): if isinstance(other, TypeCode): + # noinspection PyUnresolvedReferences return other.type == 'c' elif isinstance(other, basestring): return other == 'record' @@ -1730,6 +1752,7 @@ def __eq__(self, other): def __ne__(self, other): if isinstance(other, TypeCode): + # noinspection PyUnresolvedReferences return other.type != 'c' elif isinstance(other, basestring): return other != 'record' @@ -1786,9 +1809,10 @@ def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0, - tzinfo=None): + tzinfo=None): """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, second, microsecond, tzinfo) + return datetime(year, month, day, hour, minute, second, microsecond, + tzinfo) def DateFromTicks(ticks): @@ -1815,7 +1839,7 @@ class Binary(bytes): def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0): """Construct an object holding a time interval value.""" return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, - microseconds=microseconds) + microseconds=microseconds) Uuid = Uuid # Construct an object holding a UUID value diff --git a/pgmodule.c b/pgmodule.c index 20c47993..a3d68c54 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -1258,10 +1258,10 @@ MODULE_INIT_FUNC(_pg) dict = PyModule_GetDict(mod); /* Exceptions as defined by DB-API 2.0 */ - Error = PyErr_NewException("pg.Error", PyExc_Exception, NULL); + Error = PyErr_NewException("pg.Error", PyExc_StandardError, NULL); PyDict_SetItemString(dict, "Error", Error); - Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); + Warning = PyErr_NewException("pg.Warning", PyExc_StandardError, NULL); PyDict_SetItemString(dict, "Warning", Warning); InterfaceError = PyErr_NewException( diff --git a/py3c.h b/py3c.h index 63a3222a..c137b191 100644 --- a/py3c.h +++ b/py3c.h @@ -57,6 +57,8 @@ #define Py_TPFLAGS_HAVE_ITER 0 // not needed in Python 3 +#define 
PyExc_StandardError PyExc_Exception // exists only in Python 2 + #else /***** Python 2 *****/ diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 37027183..2bb7e2b0 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,3 +1,4 @@ +#!/usr/bin/python """Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -9,43 +10,45 @@ import time try: # noinspection PyUnresolvedReferences - _BaseException = StandardError -except Exception: # Python 2 + _BaseException = StandardError # noqa: F821 +except NameError: # Python >= 3.0 _BaseException = Exception -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode -except NameError: # Python 3 +except NameError: # Python >= 3.0 unicode = str def str2bytes(sval): if str is not unicode and isinstance(sval, str): + # noinspection PyUnresolvedReferences sval = sval.decode("latin1") return sval.encode("latin1") # python 3 make unicode into bytes class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compliance with the DB-API. It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. + """Test a database self.driver for DB API 2.0 compatibility. + + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. - The 'Optional Extensions' are not yet being tested. + The 'Optional Extensions' are not yet being tested. - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ # The self.driver module. 
This should be the module where the 'connect' # method is to be found @@ -54,13 +57,14 @@ class mytest(dbapi20.DatabaseAPI20Test): connect_kw_args = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20), drink varchar(30))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + ddl1 = 'create table %sbooze (name varchar(20))' % (table_prefix,) + ddl2 = 'create table %sbarflys (name varchar(20), drink varchar(30))' % ( + table_prefix,) + xddl1 = 'drop table %sbooze' % (table_prefix,) + xddl2 = 'drop table %sbarflys' % (table_prefix,) insert = 'insert' - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = 'lower' # Name of stored procedure to convert str to lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. @@ -71,16 +75,20 @@ def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - ''' self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. - ''' + """Set up test fixture. + + self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + """ pass def tearDown(self): - ''' self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created. - ''' + """Tear down test fixture. + + self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + """ try: con = self._connect() try: @@ -125,7 +133,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0, 1, 2, 3)) + self.assertIn(threadsafety, (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -134,25 +142,25 @@ def test_paramstyle(self): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark', 'numeric', 'named', 'format', 'pyformat' - )) + self.assertIn(paramstyle, ( + 'qmark', 'numeric', 'named', 'format', 'pyformat')) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. 
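For reference, the hierarchy that these subclass assertions accept is the standard PEP 249 layout. A minimal pure-Python sketch follows (the class names are the DB-API standard ones, not PyGreSQL internals; deriving the specific errors from DatabaseError, as PEP 249 recommends, also satisfies the weaker Error-only checks asserted below):

# Minimal sketch of a PEP 249 exception hierarchy that passes the
# subclass checks in test_Exceptions (assumption: plain Exception as
# the root on Python 3, mirroring the PyExc_StandardError alias above).

class Warning(Exception):  # shadows the builtin on purpose, per PEP 249
    """Important warnings such as data truncation while inserting."""

class Error(Exception):
    """Base class of all other error exceptions."""

class InterfaceError(Error):
    """Errors related to the database interface itself."""

class DatabaseError(Error):
    """Errors related to the database."""

class DataError(DatabaseError):
    """Errors due to problems with the processed data."""

class OperationalError(DatabaseError):
    """Errors related to the operation of the database."""

class IntegrityError(DatabaseError):
    """Errors when the relational integrity is affected."""

class InternalError(DatabaseError):
    """Errors when the database encounters an internal problem."""

class ProgrammingError(DatabaseError):
    """Programming errors such as table not found or bad SQL syntax."""

class NotSupportedError(DatabaseError):
    """Errors for methods or APIs the database does not support."""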
- self.assertTrue(issubclass(self.driver.Warning, _BaseException)) - self.assertTrue(issubclass(self.driver.Error, Exception)) - - self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) - self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) + sub = issubclass + self.assertTrue(sub(self.driver.Warning, _BaseException)) + self.assertTrue(sub(self.driver.Error, _BaseException)) + + self.assertTrue(sub(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(sub(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(sub(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(sub(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(sub(self.driver.InternalError, self.driver.Error)) + self.assertTrue(sub(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(sub(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION @@ -187,6 +195,7 @@ def test_rollback(self): # the documented exception if hasattr(con, 'rollback'): try: + # noinspection PyCallingNonCallable con.rollback() except self.driver.NotSupportedError: pass @@ -195,6 +204,7 @@ def test_cursor(self): con = self._connect() try: cur = con.cursor() + self.assertIsNotNone(cur) finally: con.close() @@ -222,31 +232,31 @@ def test_description(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description, None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) + self.assertIsNone( + cur.description, + 'cursor.description should be none after executing a' + ' statement that can return no rows (such as DDL)') cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description), 1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]), 7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(), 'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1], self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual( + len(cur.description), 1, + 'cursor.description describes too many columns') + self.assertEqual( + len(cur.description[0]), 7, + 'cursor.description[x] tuples must have 7 elements') + self.assertEqual( + cur.description[0][0].lower(), 'name', + 'cursor.description[x][0] must return column name') + self.assertEqual( + cur.description[0][1], self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1]) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description, None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertIsNone( + cur.description, + 'cursor.description not being set to None when executing' + ' no-result statements (eg. 
DDL)') finally: con.close() @@ -255,27 +265,27 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertTrue(cur.rowcount in (-1, 0), # Bug #543885 - 'cursor.rowcount should be -1 or 0 after executing no-result ' - 'statements' - ) + self.assertIn( + cur.rowcount, (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result' + ' statements') cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( self.insert, self.table_prefix )) - self.assertTrue(cur.rowcount in (-1, 1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertIn( + cur.rowcount, (-1, 1), + 'cursor.rowcount should == number or rows inserted, or' + ' set to -1 after executing an insert statement') cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1, 1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertIn( + cur.rowcount, (-1, 1), + 'cursor.rowcount should == number of rows returned, or' + ' set to -1 after executing a select statement') self.executeDDL2(cur) - self.assertTrue(cur.rowcount in (-1, 0), # Bug #543885 - 'cursor.rowcount should be -1 or 0 after executing no-result ' - 'statements' - ) + self.assertIn( + cur.rowcount, (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result' + ' statements') finally: con.close() @@ -286,17 +296,16 @@ def test_callproc(self): try: cur = con.cursor() if self.lower_func and hasattr(cur, 'callproc'): + # noinspection PyCallingNonCallable r = cur.callproc(self.lower_func, ('FOO',)) self.assertEqual(len(r), 1) self.assertEqual(r[0], 'FOO') r = cur.fetchall() self.assertEqual(len(r), 1, 'callproc produced no result set') - self.assertEqual(len(r[0]), 1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0], 'foo', - 'callproc produced invalid results' - ) + self.assertEqual( + len(r[0]), 1, 'callproc produced invalid result set') + self.assertEqual( + r[0][0], 'foo', 'callproc produced invalid results') finally: con.close() @@ -319,7 +328,7 @@ def test_non_idempotent_close(self): con = self._connect() con.close() # connection.close should raise an Error if called more than once - # !!! reasonable persons differ about the usefulness of this test and this feature !!! + # (the usefulness of this test and this feature is questionable) self.assertRaises(self.driver.Error, con.close) def test_execute(self): @@ -333,68 +342,69 @@ def test_execute(self): def _paraminsert(self, cur): self.executeDDL2(cur) cur.execute( - "%s into %sbarflys values ('Victoria Bitter', 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1, 1)) + "%s into %sbarflys values ('Victoria Bitter'," + " 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix)) + self.assertIn(cur.rowcount, (-1, 1)) if self.driver.paramstyle == 'qmark': cur.execute( - "%s into %sbarflys values (?, 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), - ("Cooper's",) - ) + "%s into %sbarflys values (?," + " 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + ("Cooper's",)) elif self.driver.paramstyle == 'numeric': cur.execute( - "%s into %sbarflys values (:1, 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), - ("Cooper's",) - ) + "%s into %sbarflys values (:1," + " 'thi%%s :may ca%%(u)se? 
troub:1e')" % ( + self.insert, self.table_prefix), + ("Cooper's",)) elif self.driver.paramstyle == 'named': cur.execute( - "%s into %sbarflys values (:beer, 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), - {'beer': "Cooper's"} - ) + "%s into %sbarflys values (:beer," + " 'thi%%s :may ca%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + {'beer': "Cooper's"}) elif self.driver.paramstyle == 'format': cur.execute( - "%s into %sbarflys values (%%s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), - ("Cooper's",) - ) + "%s into %sbarflys values (%%s," + " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + ("Cooper's",)) elif self.driver.paramstyle == 'pyformat': cur.execute( - "%s into %sbarflys values (%%(beer)s, 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), - {'beer': "Cooper's"} - ) + "%s into %sbarflys values (%%(beer)s," + " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( + self.insert, self.table_prefix), + {'beer': "Cooper's"}) else: self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1, 1)) + self.assertIn(cur.rowcount, (-1, 1)) cur.execute('select name, drink from %sbarflys' % self.table_prefix) res = cur.fetchall() self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0], "Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1], "Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], "Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly') + self.assertEqual( + beers[1], "Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly') trouble = "thi%s :may ca%(u)se? troub:1e" - self.assertEqual(res[0][1], trouble, - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly. Got=%s, Expected=%s' % ( - repr(res[0][1]), repr(trouble))) - self.assertEqual(res[1][1], trouble, - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly. Got=%s, Expected=%s' % ( - repr(res[1][1]), repr(trouble) - )) + self.assertEqual( + res[0][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly. Got=%s, Expected=%s' % ( + repr(res[0][1]), repr(trouble))) + self.assertEqual( + res[1][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly. 
Got=%s, Expected=%s' % ( + repr(res[1][1]), repr(trouble))) def test_executemany(self): con = self._connect() @@ -406,45 +416,34 @@ def test_executemany(self): if self.driver.paramstyle == 'qmark': cur.executemany( '%s into %sbooze values (?)' % ( - self.insert, self.table_prefix), - largs - ) + self.insert, self.table_prefix), largs) elif self.driver.paramstyle == 'numeric': cur.executemany( '%s into %sbooze values (:1)' % ( - self.insert, self.table_prefix), - largs - ) + self.insert, self.table_prefix), largs) elif self.driver.paramstyle == 'named': cur.executemany( '%s into %sbooze values (:beer)' % ( - self.insert, self.table_prefix), - margs - ) + self.insert, self.table_prefix), margs) elif self.driver.paramstyle == 'format': cur.executemany( '%s into %sbooze values (%%s)' % ( - self.insert, self.table_prefix), - largs - ) + self.insert, self.table_prefix), largs) elif self.driver.paramstyle == 'pyformat': cur.executemany( '%s into %sbooze values (%%(beer)s)' % ( - self.insert, self.table_prefix - ), - margs - ) + self.insert, self.table_prefix), margs) else: self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1, 2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount - ) + self.assertIn( + cur.rowcount, (-1, 2), + 'insert using cursor.executemany set cursor.rowcount to' + ' incorrect value %r' % cur.rowcount) cur.execute('select name from %sbooze' % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res), 2, - 'cursor.fetchall retrieved incorrect number of rows' - ) + self.assertEqual( + len(res), 2, + 'cursor.fetchall retrieved incorrect number of rows') beers = [res[0][0], res[1][0]] beers.sort() self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved') @@ -462,19 +461,19 @@ def test_fetchone(self): self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows + # executing a query that cannot return rows self.executeDDL1(cur) self.assertRaises(self.driver.Error, cur.fetchone) cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(), None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1, 0)) + self.assertIsNone( + cur.fetchone(), + 'cursor.fetchone should return None if a query retrieves' + ' no rows') + self.assertIn(cur.rowcount, (-1, 0)) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows + # executing a query that cannot return rows cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( self.insert, self.table_prefix )) @@ -482,16 +481,16 @@ def test_fetchone(self): cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r), 1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0], 'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(), None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1, 1)) + self.assertEqual( + len(r), 1, + 'cursor.fetchone should have retrieved a single row') + self.assertEqual( + r[0], 'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data') + self.assertIsNone( + cur.fetchone(), + 'cursor.fetchone should return None if no more rows available') + self.assertIn(cur.rowcount, (-1, 1)) finally: con.close() @@ -505,14 +504,11 @@ def 
test_fetchone(self): ] def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch - tests. - ''' + """Return a list of SQL commands to setup the DB for fetching tests.""" populate = [ "%s into %sbooze values ('%s')" % ( - self.insert, self.table_prefix, s) - for s in self.samples - ] + self.insert, self.table_prefix, s) + for s in self.samples] return populate def test_fetchmany(self): @@ -530,43 +526,43 @@ def test_fetchmany(self): cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r), 1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) + self.assertEqual( + len(r), 1, + 'cursor.fetchmany retrieved incorrect number of rows,' + ' default of arraysize is one.') cur.arraysize = 10 r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r), 3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) + self.assertEqual( + len(r), 3, + 'cursor.fetchmany retrieved incorrect number of rows') r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r), 2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) + self.assertEqual( + len(r), 2, + 'cursor.fetchmany retrieved incorrect number of rows') r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r), 0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' - ) - self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(r), 0, + 'cursor.fetchmany should return an empty sequence after' + ' results are exhausted') + self.assertIn(cur.rowcount, (-1, 6)) # Same as above, using cursor.arraysize cur.arraysize = 4 cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r), 4, - 'cursor.arraysize not being honoured by fetchmany' - ) + self.assertEqual( + len(r), 4, + 'cursor.arraysize not being honoured by fetchmany') r = cur.fetchmany() # Should get 2 more self.assertEqual(len(r), 2) r = cur.fetchmany() # Should be an empty sequence self.assertEqual(len(r), 0) - self.assertTrue(cur.rowcount in (-1, 6)) + self.assertIn(cur.rowcount, (-1, 6)) cur.arraysize = 6 cur.execute('select name from %sbooze' % self.table_prefix) rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1, 6)) + self.assertIn(cur.rowcount, (-1, 6)) self.assertEqual(len(rows), 6) self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] @@ -574,25 +570,25 @@ def test_fetchmany(self): # Make sure we get the right data back out for i in range(0, 6): - self.assertEqual(rows[i], self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) + self.assertEqual( + rows[i], self.samples[i], + 'incorrect data retrieved by cursor.fetchmany') rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows), 0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(rows), 0, + 'cursor.fetchmany should return an empty sequence if' + ' called after the whole result set has been fetched') + self.assertIn(cur.rowcount, (-1, 6)) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r), 0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(r), 0, + 
'cursor.fetchmany should return an empty sequence if' + ' query retrieved no rows') + self.assertIn(cur.rowcount, (-1, 0)) finally: con.close() @@ -616,32 +612,30 @@ def test_fetchall(self): cur.execute('select name from %sbooze' % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1, len(self.samples))) - self.assertEqual(len(rows), len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) - rows = [r[0] for r in rows] - rows.sort() + self.assertIn(cur.rowcount, (-1, len(self.samples))) + self.assertEqual( + len(rows), len(self.samples), + 'cursor.fetchall did not retrieve all rows') + rows = sorted(r[0] for r in rows) for i in range(0, len(self.samples)): - self.assertEqual(rows[i], self.samples[i], - 'cursor.fetchall retrieved incorrect rows' - ) + self.assertEqual( + rows[i], self.samples[i], + 'cursor.fetchall retrieved incorrect rows') rows = cur.fetchall() self.assertEqual( len(rows), 0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1, len(self.samples))) + 'cursor.fetchall should return an empty list if called' + ' after the whole result set has been fetched') + self.assertIn(cur.rowcount, (-1, len(self.samples))) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1, 0)) - self.assertEqual(len(rows), 0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertIn(cur.rowcount, (-1, 0)) + self.assertEqual( + len(rows), 0, + 'cursor.fetchall should return an empty list if' + ' a select query returns no rows') finally: con.close() @@ -659,13 +653,13 @@ def test_mixedfetch(self): rows23 = cur.fetchmany(2) rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1, 6)) - self.assertEqual(len(rows23), 2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56), 2, - 'fetchall returned incorrect number of rows' - ) + self.assertIn(cur.rowcount, (-1, 6)) + self.assertEqual( + len(rows23), 2, + 'fetchmany returned incorrect number of rows') + self.assertEqual( + len(rows56), 2, + 'fetchall returned incorrect number of rows') rows = [rows1[0]] rows.extend([rows23[0][0], rows23[1][0]]) @@ -673,19 +667,20 @@ def test_mixedfetch(self): rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() for i in range(0, len(self.samples)): - self.assertEqual(rows[i], self.samples[i], - 'incorrect data retrieved or inserted' - ) + self.assertEqual( + rows[i], self.samples[i], + 'incorrect data retrieved or inserted') finally: con.close() def help_nextset_setUp(self, cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' + """Set up nextset test. + + Should create a procedure called deleteme that returns two result sets, + first the number of rows in booze, then "name from booze". + """ raise NotImplementedError('Helper not implemented') - # sql=""" + # sql = """ # create procedure deleteme as # begin # select count(*) from booze @@ -695,10 +690,14 @@ def help_nextset_setUp(self, cur): # cur.execute(sql) def help_nextset_tearDown(self, cur): - 'If cleaning up is needed after nextSetTest' + """Clean up after nextset test. + + If cleaning up is needed after nextSetTest. 
+ """ raise NotImplementedError('Helper not implemented') # cur.execute("drop procedure deleteme") + # example test implementation only def test_nextset(self): con = self._connect() try: @@ -708,27 +707,27 @@ def test_nextset(self): try: self.executeDDL1(cur) - sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) cur.callproc('deleteme') - numberofrows = cur.fetchone() - assert numberofrows[0] == len(self.samples) - assert cur.nextset() + number_of_rows = cur.fetchone() + self.assertEqual(number_of_rows[0], len(self.samples)) + self.assertTrue(cur.nextset()) names = cur.fetchall() - assert len(names) == len(self.samples) + self.assertEqual(len(names), len(self.samples)) s = cur.nextset() - assert s == None, 'No more return sets, should return None' + self.assertIsNone(s, 'No more return sets, should return None') finally: self.help_nextset_tearDown(cur) finally: con.close() - def test_nextset(self): + # noinspection PyRedeclaration + def test_nextset(self): # noqa: F811 raise NotImplementedError('Drivers need to override this test') def test_arraysize(self): @@ -736,9 +735,8 @@ def test_arraysize(self): con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur, 'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue( + hasattr(cur, 'arraysize'), 'cursor.arraysize must be defined') finally: con.close() @@ -775,12 +773,12 @@ def test_None(self): # need the first one to be primary key, which means it needs # to have a non-NULL value cur.execute("%s into %sbarflys values ('a', NULL)" % ( - self.insert, self.table_prefix)) + self.insert, self.table_prefix)) cur.execute('select drink from %sbarflys' % self.table_prefix) r = cur.fetchall() self.assertEqual(len(r), 1) self.assertEqual(len(r[0]), 1) - self.assertEqual(r[0][0], None, 'NULL value not returned as None') + self.assertIsNone(r[0][0], 'NULL value not returned as None') finally: con.close() @@ -789,14 +787,14 @@ def test_Date(self): d2 = self.driver.DateFromTicks( time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(d1),str(d2)) + self.assertEqual(str(d1), str(d2)) def test_Time(self): t1 = self.driver.Time(13, 45, 30) t2 = self.driver.TimeFromTicks( time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) + self.assertEqual(str(t1), str(t2)) def test_Timestamp(self): t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) @@ -804,33 +802,28 @@ def test_Timestamp(self): time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) ) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) + self.assertEqual(str(t1), str(t2)) def test_Binary(self): - b = self.driver.Binary(str2bytes('Something')) - b = self.driver.Binary(str2bytes('')) + self.driver.Binary(str2bytes('Something')) + self.driver.Binary(str2bytes('')) def test_STRING(self): self.assertTrue(hasattr(self.driver, 'STRING'), - 'module.STRING must be defined' - ) + 'module.STRING must be defined') def test_BINARY(self): self.assertTrue(hasattr(self.driver, 'BINARY'), - 'module.BINARY must be defined.' - ) + 'module.BINARY must be defined.') def test_NUMBER(self): self.assertTrue(hasattr(self.driver, 'NUMBER'), - 'module.NUMBER must be defined.' - ) + 'module.NUMBER must be defined.') def test_DATETIME(self): self.assertTrue(hasattr(self.driver, 'DATETIME'), - 'module.DATETIME must be defined.' 
- ) + 'module.DATETIME must be defined.') def test_ROWID(self): self.assertTrue(hasattr(self.driver, 'ROWID'), - 'module.ROWID must be defined.' - ) + 'module.ROWID must be defined.') diff --git a/tests/test_classic.py b/tests/test_classic.py index 397f3fd4..f106e0a4 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -5,7 +5,6 @@ import unittest -import sys from functools import partial from time import sleep from threading import Thread @@ -19,15 +18,15 @@ dbport = 5432 try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass -def opendb(): +def open_db(): db = DB(dbname, dbhost, dbport) db.query("SET DATESTYLE TO 'ISO'") db.query("SET TIME ZONE 'EST5EDT'") @@ -36,58 +35,50 @@ def opendb(): db.query("SET STANDARD_CONFORMING_STRINGS=FALSE") return db -db = opendb() -for q in ( - "DROP TABLE _test1._test_schema", - "DROP TABLE _test2._test_schema", - "DROP SCHEMA _test1", - "DROP SCHEMA _test2", -): - try: - db.query(q) - except Exception: - pass -db.close() - class UtilityTest(unittest.TestCase): - def setUp(self): - """Setup test tables or empty them if they already exist.""" - db = opendb() - - for t in ('_test1', '_test2'): - try: - db.query("CREATE SCHEMA " + t) - except Error: - pass - try: - db.query( - "CREATE TABLE %s._test_schema " - "(%s int PRIMARY KEY)" % (t, t)) - except Error: - db.query("DELETE FROM %s._test_schema" % t) + @classmethod + def setupClass(cls): + """Drop test tables""" + db = open_db() try: - db.query( - "CREATE TABLE _test_schema " - "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") - except Error: - db.query("DELETE FROM _test_schema") + db.query("DROP VIEW _test_vschema") + except Exception: + pass try: - db.query( - "CREATE VIEW _test_vschema AS " - "SELECT _test, 'abc'::text AS _test2 FROM _test_schema") - except Error: + db.query("DROP TABLE _test_schema") + except Exception: pass + db.query("CREATE TABLE _test_schema " + "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") + db.query("CREATE VIEW _test_vschema AS" + " SELECT _test, 'abc'::text AS _test2 FROM _test_schema") + for t in ('_test1', '_test2'): + try: + db.query("DROP SCHEMA %s CASCADE") + except Exception: + pass + db.query("CREATE TABLE %s._test_schema" + " (%s int PRIMARY KEY)" % (t, t)) + db.close() - def test_invalidname(self): + def setUp(self): + """Setup test tables or empty them if they already exist.""" + db = open_db() + db.query("TRUNCATE TABLE _test_schema") + for t in ('_test1', '_test2'): + db.query("TRUNCATE TABLE %s._test_schema" % t) + db.close() + + def test_invalid_name(self): """Make sure that invalid table names are caught""" - db = opendb() + db = open_db() self.assertRaises(NotSupportedError, db.get_attnames, 'x.y.z') def test_schema(self): """Does it differentiate the same table name in different schemas""" - db = opendb() + db = open_db() # see if they differentiate the table names properly self.assertEqual( db.get_attnames('_test_schema'), @@ -107,7 +98,7 @@ def test_schema(self): ) def test_pkey(self): - db = opendb() + db = open_db() self.assertEqual(db.pkey('_test_schema'), '_test') self.assertEqual(db.pkey('public._test_schema'), '_test') self.assertEqual(db.pkey('_test1._test_schema'), '_test1') @@ -115,7 +106,7 @@ def test_pkey(self): self.assertRaises(KeyError, db.pkey, '_test_vschema') def test_get(self): - db = opendb() + db = open_db() db.query("INSERT INTO 
_test_schema VALUES (1234)") db.get('_test_schema', 1234) db.get('_test_schema', 1234, keyname='_test') @@ -123,13 +114,13 @@ def test_get(self): db.get('_test_vschema', 1234, keyname='_test') def test_params(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34) d = db.get('_test_schema', 12) self.assertEqual(d['dvar'], 34) def test_insert(self): - db = opendb() + db = open_db() d = dict(_test=1234) db.insert('_test_schema', d) self.assertEqual(d['dvar'], 999) @@ -137,7 +128,7 @@ def test_insert(self): self.assertEqual(d['dvar'], 999) def test_context_manager(self): - db = opendb() + db = open_db() t = '_test_schema' d = dict(_test=1235) with db: @@ -163,26 +154,28 @@ def test_context_manager(self): self.assertTrue(db.get(t, 1239)) def test_sqlstate(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES (1234)") try: db.query("INSERT INTO _test_schema VALUES (1234)") except DatabaseError as error: self.assertTrue(isinstance(error, IntegrityError)) # the SQLSTATE error code for unique violation is 23505 + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '23505') def test_mixed_case(self): - db = opendb() + db = open_db() try: db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)') except Error: - db.query("DELETE FROM _test_mc") + db.query("TRUNCATE TABLE _test_mc") d = dict(_Test=1234) - db.insert('_test_mc', d) + r = db.insert('_test_mc', d) + self.assertEqual(r, d) def test_update(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES (1234)") r = db.get('_test_schema', 1234) @@ -214,7 +207,7 @@ def test_notify(self, options=None): run_as_method = options.get('run_as_method') call_notify = options.get('call_notify') two_payloads = options.get('two_payloads') - db = opendb() + db = open_db() # Get function under test, can be standalone or DB method. fut = db.notification_handler if run_as_method else partial( NotificationHandler, db) @@ -233,7 +226,7 @@ def test_notify(self, options=None): self.assertTrue(target.listening) self.assertTrue(thread.is_alive()) # Open another connection for sending notifications. - db2 = opendb() + db2 = open_db() # Generate notification from the other connection. if two_payloads: db2.begin() @@ -299,7 +292,7 @@ def test_notify_other_options(self): def test_notify_timeout(self): for run_as_method in False, True: - db = opendb() + db = open_db() # Get function under test, can be standalone or DB method. 
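Under the hood, these notification tests drive PostgreSQL's LISTEN/NOTIFY machinery through two connections. A minimal sketch of that round trip with the classic pg API (the database name is a placeholder; getnotify() polls for a pending notification and returns an (event, pid, payload) tuple, or None if nothing is pending):

import pg

listener = pg.connect('testdb')  # placeholder database name
sender = pg.connect('testdb')

# subscribe on one connection, send from the other
listener.query('listen test_event')
sender.query("notify test_event, 'hello'")  # payloads need PostgreSQL >= 9.0

notice = None
while notice is None:  # a real test would bound this loop with a timeout
    notice = listener.getnotify()

event, pid, payload = notice
assert (event, payload) == ('test_event', 'hello')

listener.close()
sender.close()

The NotificationHandler exercised above builds on this mechanism, adding callback dispatch and timeout handling so that, as in the tests, it can run on a background thread.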
fut = db.notification_handler if run_as_method else partial( NotificationHandler, db) @@ -320,24 +313,4 @@ def test_notify_timeout(self): if __name__ == '__main__': - if len(sys.argv) == 2 and sys.argv[1] == '-l': - print('\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_'))) - sys.exit(0) - - test_list = [name for name in sys.argv[1:] if not name.startswith('-')] - if not test_list: - test_list = unittest.getTestCaseNames(UtilityTest, 'test_') - - suite = unittest.TestSuite() - for test_name in test_list: - try: - suite.addTest(UtilityTest(test_name)) - except Exception: - print("\n ERROR: %s.\n" % sys.exc_value) - sys.exit(1) - - verbosity = '-v' in sys.argv[1:] and 2 or 1 - failfast = '-l' in sys.argv[1:] - runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast) - rc = runner.run(suite) - sys.exit(1 if rc.errors or rc.failures else 0) + unittest.main() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index d36c210a..9ce70684 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -16,7 +16,9 @@ import os from collections import namedtuple + try: + # noinspection PyCompatibility from collections.abc import Iterable except ImportError: # Python < 3.3 from collections import Iterable @@ -35,19 +37,19 @@ dbport = 5432 try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long except NameError: # Python >= 3.0 long = int -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode except NameError: # Python >= 3.0 unicode = str @@ -64,6 +66,7 @@ def connect(): """Create a basic pg connection to the test database.""" + # noinspection PyArgumentList connection = pg.connect(dbname, dbhost, dbport) connection.query("set client_min_messages=warning") return connection @@ -225,6 +228,7 @@ def testAllQueryMembers(self): one onedict onenamed onescalar scalariter scalarresult single singledict singlenamed singlescalar '''.split() + # noinspection PyUnresolvedReferences if pg.get_pqlib_version() < 120000: members.remove('memsize') query_members = [ @@ -467,6 +471,7 @@ def testNamedresultWithGoodFieldnames(self): def testNamedresultWithBadFieldnames(self): r = namedtuple('Bad', ['?'] * 6, rename=True) + # noinspection PyUnresolvedReferences fields = r._fields q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",' ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') @@ -647,6 +652,7 @@ def testQuery(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) @@ -684,12 +690,14 @@ def testQueryWithOids(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '5') def testMemSize(self): + # noinspection PyUnresolvedReferences if pg.get_pqlib_version() < 120000: self.skipTest("pqlib does not support memsize()") query = self.c.query @@ -738,6 +746,7 @@ def testGetresultUtf8(self): v = self.c.query(q).getresult()[0][0] except(pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") + 
v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -755,6 +764,7 @@ def testDictresultUtf8(self): v = self.c.query(q).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") + v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -762,7 +772,7 @@ def testDictresultUtf8(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin1(self): + def testGetresultLatin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): @@ -878,12 +888,12 @@ def tearDown(self): def testQueryWithNoneParam(self): self.assertRaises(TypeError, self.c.query, "select $1", None) self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None) - self.assertEqual(self.c.query("select $1::integer", (None,) - ).getresult(), [(None,)]) - self.assertEqual(self.c.query("select $1::text", [None] - ).getresult(), [(None,)]) - self.assertEqual(self.c.query("select $1::text", [[None]] - ).getresult(), [(None,)]) + self.assertEqual( + self.c.query("select $1::integer", (None,)).getresult(), [(None,)]) + self.assertEqual( + self.c.query("select $1::text", [None]).getresult(), [(None,)]) + self.assertEqual( + self.c.query("select $1::text", [[None]]).getresult(), [(None,)]) def testQueryWithBoolParams(self, bool_enabled=None): query = self.c.query @@ -910,6 +920,7 @@ def testQueryWithBoolParams(self, bool_enabled=None): self.assertEqual(query(q, (True,)).getresult(), r_true) finally: if bool_enabled is not None: + # noinspection PyUnboundLocalVariable pg.set_bool(bool_enabled_default) def testQueryWithBoolParamsNotDefault(self): @@ -974,7 +985,8 @@ def testQueryWithUnicodeParams(self): query = self.c.query try: query('set client_encoding=utf8') - query("select 'wörld'").getresult()[0][0] == 'wörld' + self.assertEqual( + query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertEqual( @@ -985,7 +997,8 @@ def testQueryWithUnicodeParamsLatin1(self): query = self.c.query try: query('set client_encoding=latin1') - query("select 'wörld'").getresult()[0][0] == 'wörld' + self.assertEqual( + query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() @@ -1015,7 +1028,8 @@ def testQueryWithUnicodeParamsCyrillic(self): query = self.c.query try: query('set client_encoding=iso_8859_5') - query("select 'мир'").getresult()[0][0] == 'мир' + self.assertEqual( + query("select 'мир'").getresult()[0][0], 'мир') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") self.assertRaises( @@ -1034,11 +1048,14 @@ def testQueryWithUnicodeParamsCyrillic(self): def testQueryWithMixedParams(self): self.assertEqual( - self.c.query("select $1+2,$2||', world!'", - (1, 'Hello'),).getresult(), [(3, 'Hello, world!')]) + self.c.query( + "select $1+2,$2||', world!'", (1, 'Hello')).getresult(), + [(3, 'Hello, world!')]) self.assertEqual( - self.c.query("select $1::integer,$2::date,$3::text", - (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')]) + self.c.query( + "select $1::integer,$2::date,$3::text", + (4711, None, 'Hello!')).getresult(), + [(4711, None, 'Hello!')]) def testQueryWithDuplicateParams(self): self.assertRaises( @@ -1090,8 +1107,8 @@ 
def testUnnamedQueryWithoutParams(self): def testNamedQueryWithoutParams(self): self.assertIsNone(self.c.prepare('hello', "select 'world'")) - self.assertEqual(self.c.query_prepared('hello').getresult(), - [('world',)]) + self.assertEqual( + self.c.query_prepared('hello').getresult(), [('world',)]) def testMultipleNamedQueriesWithoutParams(self): self.assertIsNone(self.c.prepare('query17', "select 17")) @@ -1111,14 +1128,16 @@ def testUnnamedQueryWithParams(self): def testMultipleNamedQueriesWithParams(self): self.assertIsNone(self.c.prepare('q1', "select $1 || '!'")) self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2")) - self.assertEqual(self.c.query_prepared('q1', ['hello']).getresult(), + self.assertEqual( + self.c.query_prepared('q1', ['hello']).getresult(), [('hello!',)]) - self.assertEqual(self.c.query_prepared('q2', ['he', 'lo']).getresult(), + self.assertEqual( + self.c.query_prepared('q2', ['he', 'lo']).getresult(), [('he-lo',)]) def testDescribeNonExistentQuery(self): - self.assertRaises(pg.OperationalError, - self.c.describe_prepared, 'does-not-exist') + self.assertRaises( + pg.OperationalError, self.c.describe_prepared, 'does-not-exist') def testDescribeUnnamedQuery(self): self.c.prepare('', "select 1::int, 'a'::char") @@ -1155,9 +1174,11 @@ def assert_proper_cast(self, value, pgtype, pytype): q = 'select $1::%s' % (pgtype,) try: r = self.c.query(q, (value,)).getresult()[0][0] - except pg.ProgrammingError: + except pg.ProgrammingError as e: if pgtype in ('json', 'jsonb'): self.skipTest('database does not support json') + self.fail(str(e)) + # noinspection PyUnboundLocalVariable self.assertIsInstance(r, pytype) if isinstance(value, str): if not value or ' ' in value or '{' in value: @@ -1185,11 +1206,10 @@ def testLong(self): def testFloat(self): self.assert_proper_cast(0, 'float', float) self.assert_proper_cast(0, 'real', float) - self.assert_proper_cast(0, 'double', float) self.assert_proper_cast(0, 'double precision', float) self.assert_proper_cast('infinity', 'float', float) - def testFloat(self): + def testNumeric(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal(0), 'numeric', decimal) self.assert_proper_cast(decimal(0), 'decimal', decimal) @@ -1257,6 +1277,7 @@ def testIterate(self): self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) self.assertEqual(list(r), [(3,), (4,), (5,)]) + # noinspection PyUnresolvedReferences self.assertIsInstance(r[1], tuple) def testIterateTwice(self): @@ -1578,7 +1599,7 @@ def testSingleScalarWithSingleRow(self): self.assertIsInstance(r, int) self.assertEqual(r, 1) - def testSingleWithTwoRows(self): + def testSingleScalarWithTwoRows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlescalar() @@ -1650,16 +1671,17 @@ def tearDown(self): data = [ (-1, -1, long(-1), True, '1492-10-12', '08:30:00', - -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), + -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), (0, 0, long(0), False, '1607-04-14', '09:00:00', - 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), + 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), (1, 1, long(1), True, '1801-03-04', '03:45:00', - 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), + 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), (2, 2, long(2), False, '1903-12-17', '11:22:00', - 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] + 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] @classmethod def db_len(cls, s, encoding): + # noinspection 
PyUnresolvedReferences if cls.has_encoding: s = s if isinstance(s, unicode) else s.decode(encoding) else: @@ -1770,9 +1792,9 @@ def testInserttableNullValues(self): def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), - True, '2999-12-31', '11:59:59', 1e99, - 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, - "1", "1234", "1234", "1234" * 100)] + True, '2999-12-31', '11:59:59', 1e99, + 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, + "1", "1234", "1234", "1234" * 100)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -1948,6 +1970,7 @@ def testGetline(self): for i in range(n + 2): v = getline() if i < n: + # noinspection PyStringFormat self.assertEqual(v, '%d\t%s' % data[i]) elif i == n: self.assertEqual(v, '\\.') @@ -2040,7 +2063,7 @@ def testSetNoticeReceiver(self): self.assertIsNone(self.c.set_notice_receiver(None)) def testSetAndGetNoticeReceiver(self): - r = lambda notice: None + r = lambda notice: None # noqa: E731 self.assertIsNone(self.c.set_notice_receiver(r)) self.assertIs(self.c.get_notice_receiver(), r) self.assertIsNone(self.c.set_notice_receiver(None)) @@ -2268,6 +2291,7 @@ def testSetDecimal(self): r = query("select 3425::numeric") except pg.DatabaseError: self.skipTest('database does not support numeric') + r = None r = r.getresult()[0][0] self.assertIsInstance(r, decimal_class) self.assertEqual(r, decimal_class('3425')) @@ -2325,6 +2349,7 @@ def testSetBool(self): r = query("select true::bool") except pg.ProgrammingError: self.skipTest('database does not support bool') + r = None r = r.getresult()[0][0] self.assertIsInstance(r, bool) self.assertEqual(r, True) @@ -2387,6 +2412,7 @@ def testSetByteaEscaped(self): r = query("select 'data'::bytea") except pg.ProgrammingError: self.skipTest('database does not support bytea') + r = None r = r.getresult()[0][0] self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d9362889..5e0cd73b 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -36,24 +36,24 @@ debug = False # let DB wrapper print debugging output try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long except NameError: # Python >= 3.0 long = int -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode except NameError: # Python >= 3.0 unicode = str -if str is bytes: # noinspection PyUnresolvedReferences +if str is bytes: # noinspection PyCompatibility,PyUnresolvedReferences from StringIO import StringIO else: # Python >= 3.0 from io import StringIO @@ -164,6 +164,7 @@ class TestDBClassInit(unittest.TestCase): def testBadParams(self): self.assertRaises(TypeError, pg.DB, invalid=True) + # noinspection PyUnboundLocalVariable def testDeleteDb(self): db = DB() del db.db @@ -341,6 +342,7 @@ def testMethodQueryDataError(self): try: self.db.query("select 1/0") except pg.DataError as error: + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') def testMethodEndcopy(self): @@ -821,7 +823,7 @@ def testReset(self): default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' if changed_datestyle == default_datestyle: - 
changed_datestyle == 'ISO, YMD' + changed_datestyle = 'ISO, YMD' self.db.set_parameter('datestyle', changed_datestyle) r = self.db.get_parameter('datestyle') self.assertEqual(r, changed_datestyle) @@ -842,7 +844,7 @@ def testReopen(self): default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' if changed_datestyle == default_datestyle: - changed_datestyle == 'ISO, YMD' + changed_datestyle = 'ISO, YMD' self.db.set_parameter('datestyle', changed_datestyle) r = self.db.get_parameter('datestyle') self.assertEqual(r, changed_datestyle) @@ -906,6 +908,7 @@ def testQuery(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) @@ -939,6 +942,7 @@ def testQueryWithOids(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) @@ -982,6 +986,7 @@ def testQueryDataError(self): try: self.db.query("select 1/0") except pg.DataError as error: + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') def testQueryFormatted(self): @@ -1045,8 +1050,8 @@ def testPrepare(self): p = self.db.prepare self.assertIsNone(p('my query', "select 'hello'")) self.assertIsNone(p('my other query', "select 'world'")) - self.assertRaises(pg.ProgrammingError, - p, 'my query', "select 'hello, too'") + self.assertRaises( + pg.ProgrammingError, p, 'my query', "select 'hello, too'") def testPrepareUnnamed(self): p = self.db.prepare @@ -1148,19 +1153,21 @@ def testPkey(self): for t in ('pkeytest', 'primary key test'): self.createTable('%s0' % t, 'a smallint') self.createTable('%s1' % t, 'b smallint primary key') - self.createTable('%s2' % t, - 'c smallint, d smallint primary key') - self.createTable('%s3' % t, + self.createTable('%s2' % t, 'c smallint, d smallint primary key') + self.createTable( + '%s3' % t, 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') - self.createTable('%s4' % t, + self.createTable( + '%s4' % t, 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') - self.createTable('%s5' % t, - 'more_than_one_letter varchar primary key') - self.createTable('%s6' % t, - '"with space" date primary key') - self.createTable('%s7' % t, + self.createTable( + '%s5' % t, 'more_than_one_letter varchar primary key') + self.createTable( + '%s6' % t, '"with space" date primary key') + self.createTable( + '%s7' % t, 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') self.assertRaises(KeyError, pkey, '%s0' % t) @@ -1309,7 +1316,8 @@ def testGetAttnames(self): def testGetAttnamesWithQuotes(self): get_attnames = self.db.get_attnames table = 'test table for get_attnames()' - self.createTable(table, + self.createTable( + table, '"Prime!" smallint, "much space" integer, "Questions?" 
text') r = get_attnames(table) self.assertIsInstance(r, dict) @@ -1339,16 +1347,18 @@ def testGetAttnamesWithQuotes(self): 't': 'text', 'v': 'character varying', 'y': 'smallint', 'x': 'smallint', 'z': 'smallint'}) else: - self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int', - 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money', - 'normal_name': 'int', 'Special Name': 'int', - 'u': 'text', 't': 'text', 'v': 'text', - 'y': 'int', 'x': 'int', 'z': 'int'}) + self.assertEqual(r, { + 'a': 'int', 'b': 'int', 'c': 'int', + 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money', + 'normal_name': 'int', 'Special Name': 'int', + 'u': 'text', 't': 'text', 'v': 'text', + 'y': 'int', 'x': 'int', 'z': 'int'}) def testGetAttnamesWithRegtypes(self): get_attnames = self.db.get_attnames - self.createTable('test_table', 'n int, alpha smallint, beta bool,' - ' gamma char(5), tau text, v varchar(3)') + self.createTable( + 'test_table', 'n int, alpha smallint, beta bool,' + ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes regtypes = use_regtypes() self.assertEqual(regtypes, self.regtypes) @@ -1364,8 +1374,9 @@ def testGetAttnamesWithRegtypes(self): def testGetAttnamesWithoutRegtypes(self): get_attnames = self.db.get_attnames - self.createTable('test_table', 'n int, alpha smallint, beta bool,' - ' gamma char(5), tau text, v varchar(3)') + self.createTable( + 'test_table', 'n int, alpha smallint, beta bool,' + ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes regtypes = use_regtypes() self.assertEqual(regtypes, self.regtypes) @@ -1424,8 +1435,9 @@ def testGetAttnamesIsOrdered(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable(table, 'n int, alpha smallint, v varchar(3),' - ' gamma char(5), tau text, beta bool') + self.createTable( + table, 'n int, alpha smallint, v varchar(3),' + ' gamma char(5), tau text, beta bool') r = get_attnames(table) self.assertIsInstance(r, OrderedDict) if self.regtypes: @@ -1464,7 +1476,7 @@ def testGetAttnamesIsAttrDict(self): table = 'test table for get_attnames' self.createTable( table, 'n int, alpha smallint, v varchar(3),' - ' gamma char(5), tau text, beta bool') + ' gamma char(5), tau text, beta bool') r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: @@ -1623,10 +1635,10 @@ def testGetWithOids(self): def testGetWithCompositeKey(self): get = self.db.get - query = self.db.query table = 'get_test_table_1' - self.createTable(table, 'n integer primary key, t text', - values=enumerate('abc', start=1)) + self.createTable( + table, 'n integer primary key, t text', + values=enumerate('abc', start=1)) self.assertEqual(get(table, 2)['t'], 'b') self.assertEqual(get(table, 1, 'n')['t'], 'a') self.assertEqual(get(table, 2, ('n',))['t'], 'b') @@ -1636,10 +1648,10 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, ('a',), ('t',))['n'], 1) self.assertEqual(get(table, ['c'], ['t'])['n'], 3) table = 'get_test_table_2' - self.createTable(table, - 'n integer, m integer, t text, primary key (n, m)', - values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) - for n in range(3) for m in range(2)]) + self.createTable( + table, 'n integer, m integer, t text, primary key (n, m)', + values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) + for n in range(3) for m in range(2)]) self.assertRaises(KeyError, get, table, 2) self.assertEqual(get(table, (1, 1))['t'], 'a') self.assertEqual(get(table, (1, 2))['t'], 'b') @@ -1655,11 
+1667,11 @@ def testGetWithCompositeKey(self): def testGetWithQuotedNames(self): get = self.db.get - query = self.db.query table = 'test table for get()' - self.createTable(table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text', - values=[(17, 1001, 'No!')]) + self.createTable( + table, '"Prime!" smallint primary key,' + ' "much space" integer, "Questions?" text', + values=[(17, 1001, 'No!')]) r = get(table, 17) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 17) @@ -1694,13 +1706,15 @@ def testGetLittleBobbyTables(self): try: get('test_students', "D' Arcy") except pg.DatabaseError as error: - self.assertEqual(str(error), + self.assertEqual( + str(error), 'No such record in test_students\nwhere "firstname" = $1\n' 'with $1="D\' Arcy"') try: get('test_students', "Robert'); TRUNCATE TABLE test_students;--") except pg.DatabaseError as error: - self.assertEqual(str(error), + self.assertEqual( + str(error), 'No such record in test_students\nwhere "firstname" = $1\n' 'with $1="Robert\'); TRUNCATE TABLE test_students;--"') q = "select * from test_students order by 1 limit 4" @@ -1716,52 +1730,53 @@ def testInsert(self): table = 'insert_test_table' self.createTable( table, 'i2 smallint, i4 integer, i8 bigint,' - ' d numeric, f4 real, f8 double precision, m money,' - ' v4 varchar(4), c4 char(4), t text,' - ' b boolean, ts timestamp') - tests = [dict(i2=None, i4=None, i8=None), - (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), - (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), - dict(i2=42, i4=123456, i8=9876543210), - dict(i2=2 ** 15 - 1, - i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)), - dict(d=None), (dict(d=''), dict(d=None)), - dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), - dict(f4=None, f8=None), dict(f4=0, f8=0), - (dict(f4='', f8=''), dict(f4=None, f8=None)), - (dict(d=1234.5, f4=1234.5, f8=1234.5), - dict(d=Decimal('1234.5'))), - dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875), - dict(d=Decimal('123456789.9876543212345678987654321')), - dict(m=None), (dict(m=''), dict(m=None)), - dict(m=Decimal('-1234.56')), - (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))), - dict(m=Decimal('1234.56')), dict(m=Decimal('123456')), - (dict(m='1234.56'), dict(m=Decimal('1234.56'))), - (dict(m=1234.5), dict(m=Decimal('1234.5'))), - (dict(m=-1234.5), dict(m=Decimal('-1234.5'))), - (dict(m=123456), dict(m=Decimal('123456'))), - (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))), - dict(b=None), (dict(b=''), dict(b=None)), - dict(b='f'), dict(b='t'), - (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')), - (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')), - (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')), - (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')), - (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')), - (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')), - dict(v4=None, c4=None, t=None), - (dict(v4='', c4='', t=''), dict(c4=' ' * 4)), - dict(v4='1234', c4='1234', t='1234' * 10), - dict(v4='abcd', c4='abcd', t='abcdefg'), - (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')), - dict(ts=None), (dict(ts=''), dict(ts=None)), - (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)), - dict(ts='2012-12-21 00:00:00'), - (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')), - dict(ts='2012-12-21 12:21:12'), - dict(ts='2013-01-05 12:13:14'), - dict(ts='current_timestamp')] + ' d numeric, f4 real, f8 double precision, m money,' + ' v4 varchar(4), c4 char(4), t text,' + ' b boolean, ts 
timestamp') + tests = [ + dict(i2=None, i4=None, i8=None), + (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), + (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), + dict(i2=42, i4=123456, i8=9876543210), + dict(i2=2 ** 15 - 1, + i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)), + dict(d=None), (dict(d=''), dict(d=None)), + dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), + dict(f4=None, f8=None), dict(f4=0, f8=0), + (dict(f4='', f8=''), dict(f4=None, f8=None)), + (dict(d=1234.5, f4=1234.5, f8=1234.5), + dict(d=Decimal('1234.5'))), + dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875), + dict(d=Decimal('123456789.9876543212345678987654321')), + dict(m=None), (dict(m=''), dict(m=None)), + dict(m=Decimal('-1234.56')), + (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))), + dict(m=Decimal('1234.56')), dict(m=Decimal('123456')), + (dict(m='1234.56'), dict(m=Decimal('1234.56'))), + (dict(m=1234.5), dict(m=Decimal('1234.5'))), + (dict(m=-1234.5), dict(m=Decimal('-1234.5'))), + (dict(m=123456), dict(m=Decimal('123456'))), + (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))), + dict(b=None), (dict(b=''), dict(b=None)), + dict(b='f'), dict(b='t'), + (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')), + (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')), + (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')), + (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')), + (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')), + (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')), + dict(v4=None, c4=None, t=None), + (dict(v4='', c4='', t=''), dict(c4=' ' * 4)), + dict(v4='1234', c4='1234', t='1234' * 10), + dict(v4='abcd', c4='abcd', t='abcdefg'), + (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')), + dict(ts=None), (dict(ts=''), dict(ts=None)), + (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)), + dict(ts='2012-12-21 00:00:00'), + (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')), + dict(ts='2012-12-21 12:21:12'), + dict(ts='2013-01-05 12:13:14'), + dict(ts='current_timestamp')] for test in tests: if isinstance(test, dict): data = test @@ -1798,7 +1813,7 @@ def testInsert(self): data = query('select * from "%s"' % table).dictresult()[0] data = dict(item for item in data.items() if item[0] in expect) self.assertEqual(data, expect) - query('delete from "%s"' % table) + query('truncate table "%s"' % table) def testInsertWithOids(self): if not self.oids: @@ -1853,7 +1868,7 @@ def testInsertWithOids(self): q = 'select n from test_table order by 1 limit 9' r = ' '.join(str(row[0]) for row in query(q).getresult()) self.assertEqual(r, '1 2 3 3 3 4 5 6') - query("truncate test_table") + query("truncate table test_table") query("alter table test_table add unique (n)") r = insert('test_table', dict(n=7)) self.assertIsInstance(r, dict) @@ -1892,7 +1907,7 @@ def testInsertWithQuotedNames(self): def testInsertIntoView(self): insert = self.db.insert query = self.db.query - query("truncate test") + query("truncate table test") q = 'select * from test_view order by i4 limit 3' r = query(q).getresult() self.assertEqual(r, []) @@ -2354,8 +2369,8 @@ def testClear(self): i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) table = 'clear_test_table' - self.createTable(table, - 'n integer, f float, b boolean, d date, t text') + self.createTable( + table, 'n integer, f float, b boolean, d date, t text') r = clear(table) result = dict(n=0, f=0, b=f, d='', t='') self.assertEqual(r, result) @@ -2370,8 +2385,9 @@ def testClear(self): def 
testClearWithQuotedNames(self): clear = self.db.clear table = 'test table for clear()' - self.createTable(table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text') + self.createTable( + table, '"Prime!" smallint primary key,' + ' "much space" integer, "Questions?" text') r = clear(table) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 0) @@ -2527,15 +2543,15 @@ def testDeleteWithCompositeKey(self): self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1) r = [r[0] for r in query('select t from "%s" where n=2' - ' order by m' % table).getresult()] + ' order by m' % table).getresult()] self.assertEqual(r, ['c']) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0) r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + ' order by m' % table).getresult()] self.assertEqual(r, ['e', 'f']) self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1) r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + ' order by m' % table).getresult()] self.assertEqual(r, ['f']) def testDeleteWithQuotedNames(self): @@ -2544,7 +2560,7 @@ def testDeleteWithQuotedNames(self): table = 'test table for delete()' self.createTable( table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text', + ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) r = {'Prime!': 17} r = delete(table, r) @@ -2786,6 +2802,7 @@ def testTruncateQuoted(self): r = query(q).getresult()[0][0] self.assertEqual(r, 0) + # noinspection PyUnresolvedReferences def testGetAsList(self): get_as_list = self.db.get_as_list self.assertRaises(TypeError, get_as_list) @@ -2898,6 +2915,7 @@ def testGetAsList(self): else: self.assertEqual(t, ('bart',)) + # noinspection PyUnresolvedReferences def testGetAsDict(self): get_as_dict = self.db.get_as_dict self.assertRaises(TypeError, get_as_dict) @@ -2911,8 +2929,8 @@ def testGetAsDict(self): named = hasattr(r, 'colname') colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'), (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')] - self.createTable(table, - 'id smallint primary key, rgb char(7), name varchar', + self.createTable( + table, 'id smallint primary key, rgb char(7), name varchar', values=colors) # keyname must be string, list or tuple self.assertRaises(KeyError, get_as_dict, table, 3) @@ -2946,6 +2964,7 @@ def testGetAsDict(self): self.assertIn(key, expected) row = r[key] self.assertIsInstance(row, tuple) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) if named: @@ -2967,6 +2986,7 @@ def testGetAsDict(self): row = r[key] self.assertIsInstance(row, tuple) self.assertIsInstance(row[0], str) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) if named: @@ -2981,23 +3001,26 @@ def testGetAsDict(self): self.assertIsInstance(key, tuple) row = r[key] self.assertIsInstance(row, str) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) self.assertEqual(r.keys(), expected.keys()) - r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True) + r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], + scalar=True) self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[1], row[2]) - for row in sorted(colors, key=itemgetter(1))) + expected = OrderedDict( + (row[1], row[2]) for row in sorted(colors, key=itemgetter(1))) self.assertEqual(r, expected) for key in r: 
self.assertIsInstance(key, str) row = r[key] self.assertIsInstance(row, str) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) self.assertEqual(r.keys(), expected.keys()) - r = get_as_dict(table, what='id, name', - where="rgb like '#b%'", scalar=True) + r = get_as_dict( + table, what='id, name', where="rgb like '#b%'", scalar=True) self.assertIsInstance(r, OrderedDict) expected = OrderedDict((row[0], row[2]) for row in colors[1:3]) self.assertEqual(r, expected) @@ -3009,13 +3032,15 @@ def testGetAsDict(self): self.assertEqual(row, t) self.assertEqual(r.keys(), expected.keys()) expected = r - r = get_as_dict(table, what=['name', 'id'], + r = get_as_dict( + table, what=['name', 'id'], where=['id > 1', 'id < 4', "rgb like '#b%'", "name not like 'A%'", "name not like '%t'"], scalar=True) self.assertEqual(r, expected) r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True) self.assertEqual(r, expected) - r = get_as_dict(table, keyname=('id',), what=('name', 'id'), + r = get_as_dict( + table, keyname=('id',), what=('name', 'id'), where=('id > 1', 'id < 4'), order=('id',), scalar=True) self.assertEqual(r, expected) r = get_as_dict(table, limit=1) @@ -3481,7 +3506,7 @@ def testArrayOfIds(self): r = self.db.get_attnames('arraytest') if self.regtypes: self.assertEqual(r, dict( - i='integer', c='cid[]', o='oid[]', x='xid[]')) + i='integer', c='cid[]', o='oid[]', x='xid[]')) else: self.assertEqual(r, dict( i='int', c='int[]', o='int[]', x='int[]')) @@ -3521,6 +3546,7 @@ def testArrayOfText(self): self.assertIsInstance(r['data'][1], str) self.assertIsNone(r['data'][2]) + # noinspection PyUnresolvedReferences def testArrayOfBytea(self): array_on = pg.get_array() bytea_escaped = pg.get_bytea_escaped() @@ -3646,6 +3672,7 @@ def testArrayOfJsonb(self): else: self.assertEqual(r, '{NULL,NULL}') + # noinspection PyUnresolvedReferences def testDeepArray(self): array_on = pg.get_array() self.createTable( @@ -3666,6 +3693,7 @@ def testDeepArray(self): else: self.assertTrue(r['data'].startswith('{{{"Hello,')) + # noinspection PyUnresolvedReferences def testInsertUpdateGetRecord(self): query = self.db.query query('create type test_person_type as' @@ -3685,13 +3713,13 @@ def testInsertUpdateGetRecord(self): else: self.assertEqual(person_typ, 'record') if self.regtypes: - self.assertEqual(person_typ.attnames, - dict(name='character varying', age='smallint', - married='boolean', weight='real', salary='money')) + self.assertEqual(person_typ.attnames, dict( + name='character varying', age='smallint', + married='boolean', weight='real', salary='money')) else: - self.assertEqual(person_typ.attnames, - dict(name='text', age='int', married='bool', - weight='float', salary='money')) + self.assertEqual(person_typ.attnames, dict( + name='text', age='int', married='bool', + weight='float', salary='money')) decimal = pg.get_decimal() if pg.get_bool(): bool_class = bool @@ -3764,6 +3792,7 @@ def testInsertUpdateGetRecord(self): self.assertEqual(r['id'], 3) self.assertIsNone(r['person']) + # noinspection PyUnresolvedReferences def testRecordInsertBytea(self): query = self.db.query query('create type test_person_type as' @@ -3804,6 +3833,7 @@ def testRecordInsertJson(self): p = r['person'] self.assertIsInstance(p, tuple) if pg.get_jsondecode() is None: + # noinspection PyUnresolvedReferences p = p._replace(data=json.loads(p.data)) self.assertEqual(p, person) self.assertEqual(p.name, 'John Doe') @@ -3811,6 +3841,7 @@ def testRecordInsertJson(self): self.assertEqual(p.data, person[1]) 
self.assertIsInstance(p.data, dict) + # noinspection PyUnresolvedReferences def testRecordLiteral(self): query = self.db.query query('create type test_person_type as' @@ -3899,7 +3930,7 @@ def testTimetz(self): def testTimestamp(self): query = self.db.query for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', - 'SQL, MDY', 'SQL, DMY', 'German'): + 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = datetime(2016, 3, 14) q = "select $1::timestamp" @@ -3917,7 +3948,7 @@ def testTimestamp(self): self.assertIsInstance(r, datetime) self.assertEqual(r, d) q = ("select '10000-08-01 AD'::timestamp," - " '0099-01-08 BC'::timestamp") + " '0099-01-08 BC'::timestamp") r = query(q).getresult()[0] self.assertIsInstance(r[0], datetime) self.assertIsInstance(r[1], datetime) @@ -3941,7 +3972,7 @@ def testTimestamptz(self): tzinfo = pg._get_timezone(tz) self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', - 'SQL, MDY', 'SQL, DMY', 'German'): + 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = datetime(2016, 3, 14, tzinfo=tzinfo) q = "select $1::timestamptz" @@ -3959,7 +3990,7 @@ def testTimestamptz(self): self.assertIsInstance(r, datetime) self.assertEqual(r, d) q = ("select '10000-08-01 AD'::timestamptz," - " '0099-01-08 BC'::timestamptz") + " '0099-01-08 BC'::timestamptz") r = query(q).getresult()[0] self.assertIsInstance(r[0], datetime) self.assertIsInstance(r[1], datetime) @@ -4055,6 +4086,7 @@ def testDbTypesInfo(self): self.assertNotEqual(typ.relid, 0) attnames = typ.attnames self.assertIsInstance(attnames, dict) + # noinspection PyUnresolvedReferences self.assertIs(attnames, dbtypes.get_attnames('pg_type')) self.assertIn('typname', attnames) typname = attnames['typname'] @@ -4067,6 +4099,7 @@ def testDbTypesInfo(self): self.assertEqual(typlen.typtype, 'b') # base self.assertEqual(typlen.category, 'N') # numeric + # noinspection PyUnresolvedReferences def testDbTypesTypecast(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) @@ -4082,7 +4115,7 @@ def testDbTypesTypecast(self): self.assertIs(dbtypes.get_typecast('int4'), int) self.assertNotIn('circle', dbtypes) self.assertIsNone(dbtypes.get_typecast('circle')) - squared_circle = lambda v: 'Squared Circle: %s' % v + squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 dbtypes.set_typecast('circle', squared_circle) self.assertIs(dbtypes.get_typecast('circle'), squared_circle) r = self.db.query("select '0,0,1'::circle").getresult()[0][0] @@ -4107,7 +4140,7 @@ def testGetSetTypeCast(self): self.assertIs(get_typecast('bool'), pg.cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) - squared_circle = lambda v: 'Squared Circle: %s' % v + squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 self.assertNotIn('circle', dbtypes) set_typecast('circle', squared_circle) self.assertNotIn('circle', dbtypes) @@ -4121,7 +4154,7 @@ def testGetSetTypeCast(self): def testNotificationHandler(self): # the notification handler itself is tested separately f = self.db.notification_handler - callback = lambda arg_dict: None + callback = lambda arg_dict: None # noqa: E731 handler = f('test', callback) self.assertIsInstance(handler, pg.NotificationHandler) self.assertIs(handler.db, self.db) @@ -4226,11 +4259,13 @@ def tearDownClass(cls): @classmethod def set_option(cls, option, value): + # noinspection PyUnresolvedReferences cls.saved_options[option] = getattr(pg, 
'get_' + option)() return getattr(pg, 'set_' + option)(value) @classmethod def reset_option(cls, option): + # noinspection PyUnresolvedReferences return getattr(pg, 'set_' + option)(cls.saved_options[option]) @@ -4265,15 +4300,14 @@ def testGuessSimpleType(self): self.assertEqual(f([[[False]]]), 'bool[]') r = f(('string', True, 3, 2.75, [1], [False])) self.assertEqual(r, 'record') - self.assertEqual(list(r.attnames.values()), - ['text', 'bool', 'int', 'float', 'int[]', 'bool[]']) + self.assertEqual(list(r.attnames.values()), [ + 'text', 'bool', 'int', 'float', 'int[]', 'bool[]']) def testAdaptQueryTypedList(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, - '%s,%s', (1, 2), ('int2',)) - self.assertRaises(TypeError, format_query, - '%s,%s', (1,), ('int2', 'int2')) + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) values = (3, 7.5, 'hello', True) types = ('int4', 'float4', 'text', 'bool') sql, params = format_query("select %s,%s,%s,%s", values, types) @@ -4311,7 +4345,8 @@ def testAdaptQueryTypedList(self): def testAdaptQueryTypedDict(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, + self.assertRaises( + TypeError, format_query, '%s,%s', dict(i1=1, i2=2), dict(i1='int2')) values = dict(i=3, f=7.5, t='hello', b=True) types = dict(i='int4', f='float4', t='text', b='bool') @@ -4367,7 +4402,7 @@ def testAdaptQueryUntypedList(self): self.assertEqual(sql, "$1,$2,$3") self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}']) values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']], - [[True, False], [False, True]]) + [[True, False], [False, True]]) sql, params = format_query("%s,%s,%s", values) self.assertEqual(sql, "$1,$2,$3") self.assertEqual(params, [ @@ -4392,7 +4427,8 @@ def testAdaptQueryUntypedDict(self): sql, params = format_query("%(i)s,%(t)s,%(b)s", values) self.assertEqual(sql, "$2,$3,$1") self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}']) - values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], + values = dict( + i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], b=[[True, False], [False, True]]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values) self.assertEqual(sql, "$2,$3,$1") @@ -4415,19 +4451,20 @@ def testAdaptQueryInlineList(self): self.assertEqual(params, []) values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True]) sql, params = format_query("%s,%s,%s", values, inline=True) - self.assertEqual(sql, - "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") + self.assertEqual( + sql, "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") self.assertEqual(params, []) values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']], - [[True, False], [False, True]]) + [[True, False], [False, True]]) sql, params = format_query("%s,%s,%s", values, inline=True) - self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," - "ARRAY[[true,false],[false,true]]") + self.assertEqual( + sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," + "ARRAY[[true,false],[false,true]]") self.assertEqual(params, []) values = [(3, 7.5, 'hello', True, [123], ['abc'])] sql, params = format_query('select %s', values, inline=True) - self.assertEqual(sql, - "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual( + sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) def testAdaptQueryInlineDict(self): @@ -4444,28 +4481,32 @@ def 
testAdaptQueryInlineDict(self): self.assertEqual(params, []) values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True) - self.assertEqual(sql, - "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") + self.assertEqual( + sql, "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") self.assertEqual(params, []) - values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], + values = dict( + i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], b=[[True, False], [False, True]]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True) - self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," - "ARRAY[[true,false],[false,true]]") + self.assertEqual( + sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," + "ARRAY[[true,false],[false,true]]") self.assertEqual(params, []) values = dict(record=(3, 7.5, 'hello', True, [123], ['abc'])) sql, params = format_query('select %(record)s', values, inline=True) - self.assertEqual(sql, - "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual( + sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) def testAdaptQueryWithPgRepr(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, - '%s', object(), inline=True) + self.assertRaises(TypeError, format_query, '%s', object(), inline=True) + class TestObject: + # noinspection PyMethodMayBeStatic def __pg_repr__(self): return "'adapted'" + sql, params = format_query('select %s', [TestObject()], inline=True) self.assertEqual(sql, "select 'adapted'") self.assertEqual(params, []) @@ -4491,7 +4532,8 @@ def setUpClass(cls): try: query("create schema %s" % (schema,)) except pg.ProgrammingError: - raise RuntimeError("The test user cannot create schemas.\n" + raise RuntimeError( + "The test user cannot create schemas.\n" "Grant create on database %s to the user" " for running these tests." % dbname) else: @@ -4535,7 +4577,7 @@ def testGetTables(self): else: schema = "public" for t in (schema + ".t", - schema + ".t" + str(num_schema)): + schema + ".t" + str(num_schema)): self.assertIn(t, tables) def testGetAttnames(self): @@ -4650,7 +4692,8 @@ def testDebugIsTrue(self): def testDebugIsString(self): self.db.debug = "Test with string: %s." 
self.send_queries() - self.assertEqual(self.get_output(), + self.assertEqual( + self.get_output(), "Test with string: select 1.\nTest with string: select 2.\n") def testDebugIsFileLike(self): @@ -4703,12 +4746,14 @@ def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() db.close() + self.getLeaks(fut) def testLeaksWithoutClose(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() + self.getLeaks(fut) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index d9735437..101416b8 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -19,12 +19,12 @@ from datetime import timedelta -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long except NameError: # Python >= 3.0 long = int -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode except NameError: # Python >= 3.0 unicode = str @@ -122,6 +122,7 @@ def testDefBase(self): self.assertEqual(pg.get_defbase(), d0) def testPqlibVersion(self): + # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, long) self.assertGreater(v, 90000) @@ -281,6 +282,7 @@ def testParserNested(self): for i in range(7): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) + # noinspection PyUnresolvedReferences r = r[0] self.assertEqual(r, 'abc') @@ -929,7 +931,7 @@ class TestConfigFunctions(unittest.TestCase): def testGetDatestyle(self): self.assertIsNone(pg.get_datestyle()) - def testGetDatestyle(self): + def testSetDatestyle(self): datestyle = pg.get_datestyle() try: pg.set_datestyle('ISO, YMD') diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index c826cb86..b82d56fa 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -23,16 +23,17 @@ dbport = 5432 try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass windows = os.name == 'nt' +# noinspection PyArgumentList def connect(): """Create a basic pg connection to the test database.""" connection = pg.connect(dbname, dbhost, dbport) @@ -208,6 +209,7 @@ def testRepr(self): self.assertTrue(r.startswith('= 3.0 long = int @@ -43,20 +43,19 @@ def __init__(self, value): self.value = value def __pg_repr__(self): - return "B'{0:b}'".format(self.value) + return "B'{0:b}'".format(self.value) class test_PyGreSQL(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () - connect_kw_args = {'database': dbname, - 'host': '%s:%d' % (dbhost or '', dbport or -1)} + connect_kw_args = { + 'database': dbname, 'host': '%s:%d' % (dbhost or '', dbport or -1)} lower_func = 'lower' # For stored procedure test def setUp(self): - # Call superclass setUp in case this does something in the future dbapi20.DatabaseAPI20Test.setUp(self) try: con = self._connect() @@ -65,7 +64,7 @@ def setUp(self): import pg try: # first try to log in as superuser db = pg.DB('postgres', dbhost or None, dbport or -1, - user='postgres') + user='postgres') except Exception: # then try to log in as current user db = pg.DB('postgres', dbhost or None, dbport or -1) db.query('create database ' + dbname) @@ -85,7 +84,7 @@ def test_connect_kwargs(self): con = self._connect() cur = con.cursor() cur.execute("select application_name from pg_stat_activity" - " where application_name = %s", 
(application_name,)) + " where application_name = %s", (application_name,)) self.assertEqual(cur.fetchone(), (application_name,)) def test_connect_kwargs_with_special_chars(self): @@ -94,7 +93,7 @@ def test_connect_kwargs_with_special_chars(self): con = self._connect() cur = con.cursor() cur.execute("select application_name from pg_stat_activity" - " where application_name = %s", (special_name,)) + " where application_name = %s", (special_name,)) self.assertEqual(cur.fetchone(), (special_name,)) def test_percent_sign(self): @@ -145,7 +144,9 @@ def test_callproc_two_params(self): def test_cursor_type(self): class TestCursor(pgdb.Cursor): - pass + @staticmethod + def row_factory(row): + return row # not used con = self._connect() self.assertIs(con.cursor_type, pgdb.Cursor) @@ -191,6 +192,7 @@ def row_factory(self, row): def test_build_row_factory(self): + # noinspection PyAbstractClass class TestCursor(pgdb.Cursor): def build_row_factory(self): @@ -215,6 +217,7 @@ def build_row_factory(self): self.assertIsInstance(res[1], dict) self.assertEqual(res[1], {'a': 3, 'b': 4}) + # noinspection PyUnresolvedReferences def test_cursor_with_named_columns(self): con = self._connect() cur = con.cursor() @@ -238,6 +241,7 @@ def test_cursor_with_named_columns(self): self.assertEqual(res[1], (3, 4)) self.assertEqual(res[1]._fields, ('one', 'two')) + # noinspection PyUnresolvedReferences def test_cursor_with_unnamed_columns(self): con = self._connect() cur = con.cursor() @@ -252,6 +256,7 @@ def test_cursor_with_unnamed_columns(self): self.assertEqual(res, (1, 2, 3)) self.assertEqual(res._fields, ('one', '_1', 'three')) + # noinspection PyUnresolvedReferences def test_cursor_with_badly_named_columns(self): con = self._connect() cur = con.cursor() @@ -290,6 +295,7 @@ def test_coltypes(self): self.assertIsInstance(types, list) self.assertEqual(types, ['int2', 'int4', 'int8']) + # noinspection PyUnresolvedReferences def test_description_fields(self): con = self._connect() cur = con.cursor() @@ -381,7 +387,7 @@ def test_type_cache_typecast(self): cur = con.cursor() type_cache = con.type_cache self.assertIs(type_cache.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v + cast_int = lambda v: 'int(%s)' % v # noqa: E731 type_cache.set_typecast('int4', cast_int) query = 'select 2::int2, 4::int4, 8::int8' cur.execute(query) @@ -443,10 +449,10 @@ def test_cursor_invalidation(self): def test_fetch_2_rows(self): Decimal = pgdb.decimal_type() values = ('test', pgdb.Binary(b'\xff\x52\xb2'), - True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), - pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), - pgdb.Timestamp(2008, 10, 20, 15, 25, 35), - pgdb.Interval(15, 31, 5), 7897234) + True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), + pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), + pgdb.Timestamp(2008, 10, 20, 15, 25, 35), + pgdb.Interval(15, 31, 5), 7897234) table = self.table_prefix + 'booze' con = self._connect() try: @@ -536,6 +542,7 @@ def test_sqlstate(self): except pgdb.DatabaseError as error: self.assertTrue(isinstance(error, pgdb.DataError)) # the SQLSTATE error code for division by zero is 22012 + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') def test_float(self): @@ -591,8 +598,7 @@ def test_datetime(self): "d date, t time, ts timestamp," "tz timetz, tsz timestamptz)" % table) for n in range(3): - values = [dt.date(), dt.time(), dt, - dt.time(), dt] + values = [dt.date(), dt.time(), dt, dt.time(), dt] values[3] = values[3].replace(tzinfo=pgdb.timezone.utc) 
values[4] = values[4].replace(tzinfo=pgdb.timezone.utc) if n == 0: # input as objects @@ -604,12 +610,13 @@ def test_datetime(self): t = (dt.hour, dt.minute, dt.second, dt.microsecond) z = (pgdb.timezone.utc,) params = [pgdb.Date(*d), pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), + pgdb.Timestamp(*(d + t + z))] for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', 'sql, mdy', 'sql, dmy', 'german'): cur.execute("set datestyle to %s" % datestyle) if n != 1: + # noinspection PyUnboundLocalVariable cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) @@ -630,7 +637,7 @@ def test_datetime(self): self.assertEqual(d[4].type_code, pgdb.TIMESTAMP) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("delete from %s" % table) + cur.execute("truncate table %s" % table) finally: con.close() @@ -653,6 +660,7 @@ def test_interval(self): for intervalstyle in ('sql_standard ', 'postgres', 'postgres_verbose', 'iso_8601'): cur.execute("set intervalstyle to %s" % intervalstyle) + # noinspection PyUnboundLocalVariable cur.execute("insert into %s" " values (%%s)" % table, [param]) cur.execute("select * from %s" % table) @@ -664,14 +672,14 @@ def test_interval(self): self.assertEqual(tc, pgdb.INTERVAL) row = cur.fetchone() self.assertEqual(row, (td,)) - cur.execute("delete from %s" % table) + cur.execute("truncate table %s" % table) finally: con.close() def test_hstore(self): con = self._connect() + cur = con.cursor() try: - cur = con.cursor() cur.execute("select 'k=>v'::hstore") except pgdb.DatabaseError: try: @@ -775,8 +783,8 @@ def test_insert_record(self): table = self.table_prefix + 'booze' record = self.table_prefix + 'munch' con = self._connect() + cur = con.cursor() try: - cur = con.cursor() cur.execute("create type %s as (name varchar, age int)" % record) cur.execute("create table %s (n smallint, r %s)" % (table, record)) params = enumerate(values) @@ -853,7 +861,7 @@ def test_set_decimal_type(self): try: cur = con.cursor() # change decimal type globally to int - int_type = lambda v: int(float(v)) + int_type = lambda v: int(float(v)) # noqa: E731 self.assertTrue(pgdb.decimal_type(int_type) is int_type) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) @@ -884,7 +892,7 @@ def test_global_typecast(self): try: query = 'select 2::int2, 4::int4, 8::int8' self.assertIs(pgdb.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v + cast_int = lambda v: 'int(%s)' % v # noqa: E731 pgdb.set_typecast('int4', cast_int) con = self._connect() try: @@ -971,32 +979,32 @@ def test_set_typecast_for_arrays(self): def test_unicode_with_utf8(self): table = self.table_prefix + 'booze' - input = u"He wes Leovenaðes sone — liðe him be Drihten" + s = u"He wes Leovenaðes sone — liðe him be Drihten" con = self._connect() + cur = con.cursor() try: - cur = con.cursor() cur.execute("create table %s (t text)" % table) try: cur.execute("set client_encoding=utf8") - cur.execute(u"select '%s'" % input) + cur.execute(u"select '%s'" % s) except Exception: self.skipTest("database does not support utf8") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (input,)) + cur.execute("insert into %s values (%%s)" % table, (s,)) cur.execute("select * from %s" % table) output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (input, table)) + cur.execute("select t = '%s' from %s" % (s, 
table)) output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (input,)) + cur.execute("select t = %%s from %s" % table, (s,)) output4 = cur.fetchone()[0] finally: con.close() if str is bytes: # Python < 3.0 - input = input.encode('utf8') + s = s.encode('utf8') self.assertIsInstance(output1, str) - self.assertEqual(output1, input) + self.assertEqual(output1, s) self.assertIsInstance(output2, str) - self.assertEqual(output2, input) + self.assertEqual(output2, s) self.assertIsInstance(output3, bool) self.assertTrue(output3) self.assertIsInstance(output4, bool) @@ -1004,32 +1012,32 @@ def test_unicode_with_utf8(self): def test_unicode_with_latin1(self): table = self.table_prefix + 'booze' - input = u"Ehrt den König seine Würde, ehret uns der Hände Fleiß." + s = u"Ehrt den König seine Würde, ehret uns der Hände Fleiß." con = self._connect() try: cur = con.cursor() cur.execute("create table %s (t text)" % table) try: cur.execute("set client_encoding=latin1") - cur.execute(u"select '%s'" % input) + cur.execute(u"select '%s'" % s) except Exception: self.skipTest("database does not support latin1") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (input,)) + cur.execute("insert into %s values (%%s)" % table, (s,)) cur.execute("select * from %s" % table) output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (input, table)) + cur.execute("select t = '%s' from %s" % (s, table)) output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (input,)) + cur.execute("select t = %%s from %s" % table, (s,)) output4 = cur.fetchone()[0] finally: con.close() if str is bytes: # Python < 3.0 - input = input.encode('latin1') + s = s.encode('latin1') self.assertIsInstance(output1, str) - self.assertEqual(output1, input) + self.assertEqual(output1, s) self.assertIsInstance(output2, str) - self.assertEqual(output2, input) + self.assertEqual(output2, s) self.assertIsInstance(output3, bool) self.assertTrue(output3) self.assertIsInstance(output4, bool) @@ -1067,8 +1075,8 @@ def test_literal(self): self.assertEqual(row, (value, 'hello')) def test_json(self): - inval = {"employees": - [{"firstName": "John", "lastName": "Doe", "age": 61}]} + inval = {"employees": [ + {"firstName": "John", "lastName": "Doe", "age": 61}]} table = self.table_prefix + 'booze' con = self._connect() try: @@ -1087,8 +1095,8 @@ def test_json(self): self.assertEqual(inval, outval) def test_jsonb(self): - inval = {"employees": - [{"firstName": "John", "lastName": "Doe", "age": 61}]} + inval = {"employees": [ + {"firstName": "John", "lastName": "Doe", "age": 61}]} table = self.table_prefix + 'booze' con = self._connect() try: @@ -1154,6 +1162,12 @@ def test_fetchmany_with_keep(self): finally: con.close() + def help_nextset_setUp(self, _cur): + pass # helper not needed + + def help_nextset_tearDown(self, _cur): + pass # helper not needed + def test_nextset(self): con = self._connect() cur = con.cursor() @@ -1218,7 +1232,7 @@ def test_connection_as_contextmanager(self): try: cur = con.cursor() if autocommit: - cur.execute("truncate %s" % table) + cur.execute("truncate table %s" % table) else: cur.execute( "create table %s (n smallint check(n!=4))" % table) @@ -1302,7 +1316,7 @@ def test_no_close(self): data = ('hello', 'world') con = self._connect() cur = con.cursor() - cur.build_row_factory = lambda: tuple + cur.build_row_factory = lambda: tuple # noqa: E731 cur.execute("select %s, %s", data) row = cur.fetchone() self.assertEqual(row, data) diff --git 
a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index f86c2bee..939a6828 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -12,7 +12,7 @@ import unittest -try: +try: # noinspection PyCompatibility from collections.abc import Iterable except ImportError: # Python < 3.3 from collections import Iterable @@ -27,14 +27,14 @@ dbport = 5432 try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass -try: # noinspection PyUnresolvedReferences +try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode except NameError: # Python >= 3.0 unicode = str @@ -281,6 +281,7 @@ def test_input_unicode(self): self.copy_from(u'43\tWürstel, Käse!') self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')]) self.truncate_table() + # noinspection PyUnresolvedReferences self.copy_from(self.data_text.decode('utf-8')) self.check_table() @@ -300,7 +301,7 @@ def test_input_iterable_with_newlines(self): def test_input_iterable_bytes(self): self.copy_from(row.encode('utf-8') - for row in self.data_text.splitlines()) + for row in self.data_text.splitlines()) self.check_table() def test_sep(self): @@ -339,7 +340,7 @@ def test_columns(self): (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')]) self.check_rowcount(5) self.assertRaises(pgdb.ProgrammingError, self.copy_from, - '6\t42', columns=['id', 'age']) + '6\t42', columns=['id', 'age']) self.check_rowcount(-1) def test_csv(self): @@ -468,6 +469,7 @@ def test_generator_unicode(self): self.assertEqual(len(rows), 3) rows = ''.join(rows) self.assertIsInstance(rows, unicode) + # noinspection PyUnresolvedReferences self.assertEqual(rows, self.data_text.decode('utf-8')) def test_rowcount_increment(self): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index d1295a6b..94871ecd 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -15,10 +15,10 @@ dbport = 5432 try: - from .LOCAL_PyGreSQL import * + from .LOCAL_PyGreSQL import * # noqa: F401 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * + from LOCAL_PyGreSQL import * # noqa: F401 except ImportError: pass diff --git a/tox.ini b/tox.ini index cbf1c161..680a94f7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,13 @@ # config file for tox [tox] -envlist = py{27,35,36,37,38} +envlist = py{27,35,36,37,38},flake8 + +[testenv:flake8] +basepython = python3.8 +deps = flake8>=3.8,<4 +commands = + flake8 setup.py pg.py pgdb.py tests [testenv] commands = From 79ace01d7129b1667eb44095c6b78c6cd4e52140 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Jun 2020 20:37:18 +0200 Subject: [PATCH 020/194] Run code quality tests in Travis CI --- .travis.yml | 36 ++++++++++++++++++++++------- docs/conf.py | 29 +++++++++++++++++++++++ docs/contents/pg/adaptation.rst | 2 +- docs/contents/pg/connection.rst | 6 ++--- docs/contents/pg/db_types.rst | 2 +- docs/contents/pg/db_wrapper.rst | 7 +++--- docs/contents/pg/introduction.rst | 2 ++ docs/contents/pg/large_objects.rst | 2 +- docs/contents/pg/module.rst | 2 +- docs/contents/pg/notification.rst | 6 ++--- docs/contents/pg/query.rst | 14 +++++------ docs/contents/pgdb/adaptation.rst | 2 +- docs/contents/pgdb/connection.rst | 2 +- docs/contents/pgdb/cursor.rst | 8 +++---- docs/contents/pgdb/module.rst | 2 +- docs/contents/pgdb/typecache.rst | 2 +- docs/contents/pgdb/types.rst | 2 +- docs/contents/postgres/advanced.rst | 2 +- 
docs/contents/postgres/basic.rst | 2 +- docs/contents/postgres/func.rst | 2 +- docs/contents/postgres/syscat.rst | 2 +- docs/contents/tutorial.rst | 4 ++-- tox.ini | 10 +++++++- 23 files changed, 103 insertions(+), 45 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4bb9dd99..5173e569 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,17 +3,37 @@ language: python -python: - - "2.7" - - "3.5" - - "3.6" - - "3.7" - - "3.8" +matrix: + include: + - name: Code quality tests + env: TOXENV=flake8,docs + python: 3.8 + - name: Unit tests with Python 3.8 + env: TOXENV=py38 + python: 3.8 + - name: Unit tests with Python 3.7 + env: TOXENV=py37 + python: 3.7 + - name: Unit tests with Python 3.6 + env: TOXENV=py36 + python: 3.6 + - name: Unit tests with Python 3.5 + env: TOXENV=py35 + python: 3.5 + - name: Unit tests with Python 2.7 + env: TOXENV=py27 + python: 2.7 + +cache: + directories: + - "$HOME/.cache/pip" + - "$TRAVIS_BUILD_DIR/.tox" install: - - pip install . + - pip install tox-travis -script: python setup.py test +script: + - tox -e $TOXENV addons: postgresql: "10" diff --git a/docs/conf.py b/docs/conf.py index add72e63..cbe51f6d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -98,6 +98,35 @@ if use_cloud_theme: exclude_patterns += ['about.rst'] +# ignore certain warnings +# (references to some of the Python names do not resolve correctly) +nitpicky = True +nitpick_ignore = [ + ('py:' + t, n) for t, names in { + 'attr': ('arraysize', 'error', 'sqlstate', 'DatabaseError.sqlstate'), + 'class': ('bool', 'bytes', 'callable', 'class', + 'dict', 'float', 'function', 'int', 'iterable', + 'list', 'object', 'set', 'str', 'tuple', + 'False', 'True', 'None', + 'namedtuple', 'OrderedDict', 'decimal.Decimal', + 'bytes/str', 'list of namedtuples', 'tuple of callables', + 'type of first field', + 'Notice', 'DATETIME'), + 'data': ('defbase', 'defhost', 'defopt', 'defpasswd', 'defport', + 'defuser'), + 'exc': ('Exception', 'IOError', 'KeyError', 'MemoryError', + 'SyntaxError', 'TypeError', 'ValueError', + 'pg.InternalError', 'pg.InvalidResultError', + 'pg.MultipleResultsError', 'pg.NoResultError', + 'pg.OperationalError', 'pg.ProgrammingError'), + 'func': ('len', 'json.dumps', 'json.loads'), + 'meth': ('datetime.strptime', + 'cur.execute', + 'DB.close', 'DB.connection_handler', 'DB.get_regtypes', + 'DB.inserttable', 'DB.reopen'), + 'obj': ('False', 'True', 'None') + }.items() for n in names] + # The reST default role (used for this markup: `text`) for all documents. #default_role = None diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index 6ed6e779..8c09be23 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -1,7 +1,7 @@ Remarks on Adaptation and Typecasting ===================================== -.. py:currentmodule:: pg +.. currentmodule:: pg Both PostgreSQL and Python have the concept of data types, but there are of course differences between the two type systems. Therefore PyGreSQL diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 8556c5d2..56a8fe4e 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -1,7 +1,7 @@ Connection -- The connection object =================================== -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: Connection @@ -478,8 +478,8 @@ locreate -- create a large object in the database [LO] This method creates a large object in the database. 
The mode can be defined by OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, -:const:`INV_WRITE` and :const:`INV_ARCHIVE`). Please refer to PostgreSQL -user manual for a description of the mode values. +and :const:`INV_WRITE`). Please refer to PostgreSQL user manual for a +description of the mode values. getlo -- build a large object from given oid [LO] ------------------------------------------------- diff --git a/docs/contents/pg/db_types.rst b/docs/contents/pg/db_types.rst index 3318fd06..2119ecd3 100644 --- a/docs/contents/pg/db_types.rst +++ b/docs/contents/pg/db_types.rst @@ -1,7 +1,7 @@ DbTypes -- The internal cache for database types ================================================ -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: DbTypes diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 540871fc..9cef63d4 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -1,7 +1,7 @@ The DB wrapper class ==================== -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: DB @@ -48,8 +48,7 @@ You can also initialize the DB class with an existing :mod:`pg` or :mod:`pgdb` connection. Pass this connection as a single unnamed parameter, or as a single parameter named ``db``. This allows you to use all of the methods of the DB class with a DB-API 2 compliant connection. Note that the -:meth:`Connection.close` and :meth:`Connection.reopen` methods are inoperative -in this case. +:meth:`DB.close` and :meth:`DB.reopen` methods are inoperative in this case. pkey -- return the primary key of a table ----------------------------------------- @@ -799,7 +798,7 @@ escape_literal/identifier/string/bytea -- escape for SQL -------------------------------------------------------- The following methods escape text or binary strings so that they can be -inserted directly into an SQL command. Except for :meth:`DB.escape_byte`, +inserted directly into an SQL command. Except for :meth:`DB.escape_bytea`, you don't need to call these methods for the strings passed as parameters to :meth:`DB.query`. You also don't need to call any of these methods when storing data using :meth:`DB.insert` and similar. diff --git a/docs/contents/pg/introduction.rst b/docs/contents/pg/introduction.rst index 6a4ca7b8..1e369e12 100644 --- a/docs/contents/pg/introduction.rst +++ b/docs/contents/pg/introduction.rst @@ -1,6 +1,8 @@ Introduction ============ +.. currentmodule:: pg + You may either choose to use the "classic" PyGreSQL interface provided by the :mod:`pg` module or else the newer DB-API 2.0 compliant interface provided by the :mod:`pgdb` module. diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index d195eb4c..3efa5d3b 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -1,7 +1,7 @@ LargeObject -- Large Objects ============================ -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: LargeObject diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 4a1cbea1..adb04845 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -1,7 +1,7 @@ Module functions and constants ============================== -.. py:currentmodule:: pg +.. 
currentmodule:: pg The :mod:`pg` module defines a few functions that allow to connect to a database and to define "default variables" that override diff --git a/docs/contents/pg/notification.rst b/docs/contents/pg/notification.rst index a37df668..05b04a16 100644 --- a/docs/contents/pg/notification.rst +++ b/docs/contents/pg/notification.rst @@ -1,7 +1,7 @@ The Notification Handler ======================== -.. py:currentmodule:: pg +.. currentmodule:: pg PyGreSQL comes with a client-side asynchronous notification handler that was based on the ``pgnotify`` module written by Ng Pheng Siong. @@ -25,7 +25,7 @@ Instantiating the notification handler :param str stop_event: an optional different name to be used as stop event You can also create an instance of the NotificationHandler using the -:class:`DB.connection_handler` method. In this case you don't need to +:meth:`DB.connection_handler` method. In this case you don't need to pass a database connection because the :class:`DB` connection itself will be used as the datebase connection for the notification handler. @@ -116,4 +116,4 @@ or when it is closed or deleted. You can call this method instead of :meth:`NotificationHandler.unlisten` if you want to close not only the handler, but also the database connection -it was created with. \ No newline at end of file +it was created with. diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 311efaff..be764c74 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -1,7 +1,7 @@ Query methods ============= -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: Query @@ -204,7 +204,7 @@ It returns None if the result does not contain one more row. Get one row from the result of a query as named tuple :returns: next row from the query results as a named tuple - :rtype: named tuple or None + :rtype: namedtuple or None :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -244,7 +244,7 @@ single/singledict/singlenamed/singlescalar -- get single result of a query :returns: single row from the query results as a tuple of fields :rtype: tuple - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -263,7 +263,7 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. :returns: single row from the query results as a dictionary :rtype: dict - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -282,8 +282,8 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. Get single row from the result of a query as named tuple :returns: single row from the query results as a named tuple - :rtype: named tuple - :raises InvalidResultError: result does not have exactly one row + :rtype: namedtuple + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -306,7 +306,7 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. 
:returns: single row from the query results as a scalar value :rtype: type of first field - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst index 1295b44f..0f9ad5a6 100644 --- a/docs/contents/pgdb/adaptation.rst +++ b/docs/contents/pgdb/adaptation.rst @@ -1,7 +1,7 @@ Remarks on Adaptation and Typecasting ===================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb Both PostgreSQL and Python have the concept of data types, but there are of course differences between the two type systems. Therefore PyGreSQL diff --git a/docs/contents/pgdb/connection.rst b/docs/contents/pgdb/connection.rst index 958108b7..71492847 100644 --- a/docs/contents/pgdb/connection.rst +++ b/docs/contents/pgdb/connection.rst @@ -1,7 +1,7 @@ Connection -- The connection object =================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. class:: Connection diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index a2ac63e8..52d600e8 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -1,7 +1,7 @@ Cursor -- The cursor object =========================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. class:: Cursor @@ -150,7 +150,7 @@ fetchone -- fetch next row of the query result Fetch the next row of a query result set :returns: the next row of the query result set - :rtype: named tuple or None + :rtype: namedtuple or None Fetch the next row of a query result set, returning a single named tuple, or ``None`` when no more data is available. The field names of the named @@ -176,7 +176,7 @@ fetchmany -- fetch next set of rows of the query result :param keep: if set to true, will keep the passed arraysize :tpye keep: bool :returns: the next set of rows of the query result - :rtype: list of named tuples + :rtype: list of namedtuples Fetch the next set of rows of a query result, returning a list of named tuples. An empty sequence is returned when no more rows are available. @@ -212,7 +212,7 @@ fetchall -- fetch all rows of the query result Fetch all (remaining) rows of a query result :returns: the set of all rows of the query result - :rtype: list of named tuples + :rtype: list of namedtuples Fetch all (remaining) rows of a query result, returning them as list of named tuples. The field names of the named tuple are the same as the column diff --git a/docs/contents/pgdb/module.rst b/docs/contents/pgdb/module.rst index 884ac4dc..5220193c 100644 --- a/docs/contents/pgdb/module.rst +++ b/docs/contents/pgdb/module.rst @@ -1,7 +1,7 @@ Module functions and constants ============================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb The :mod:`pgdb` module defines a :func:`connect` function that allows to connect to a database, some global constants describing the capabilities diff --git a/docs/contents/pgdb/typecache.rst b/docs/contents/pgdb/typecache.rst index a8b203ab..f0861a23 100644 --- a/docs/contents/pgdb/typecache.rst +++ b/docs/contents/pgdb/typecache.rst @@ -1,7 +1,7 @@ TypeCache -- The internal cache for database types ================================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. 
class:: TypeCache diff --git a/docs/contents/pgdb/types.rst b/docs/contents/pgdb/types.rst index 0c13ec6b..f28e23f7 100644 --- a/docs/contents/pgdb/types.rst +++ b/docs/contents/pgdb/types.rst @@ -1,7 +1,7 @@ Type -- Type objects and constructors ===================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. _type_constructors: diff --git a/docs/contents/postgres/advanced.rst b/docs/contents/postgres/advanced.rst index 38c8a473..e3e2ab10 100644 --- a/docs/contents/postgres/advanced.rst +++ b/docs/contents/postgres/advanced.rst @@ -1,7 +1,7 @@ Examples for advanced features ============================== -.. py:currentmodule:: pg +.. currentmodule:: pg In this section, we show how to use some advanced features of PostgreSQL using the classic PyGreSQL interface. diff --git a/docs/contents/postgres/basic.rst b/docs/contents/postgres/basic.rst index e6973442..b137351e 100644 --- a/docs/contents/postgres/basic.rst +++ b/docs/contents/postgres/basic.rst @@ -1,7 +1,7 @@ Basic examples ============== -.. py:currentmodule:: pg +.. currentmodule:: pg In this section, we demonstrate how to use some of the very basic features of PostgreSQL using the classic PyGreSQL interface. diff --git a/docs/contents/postgres/func.rst b/docs/contents/postgres/func.rst index b35e5ff7..9d0f5967 100644 --- a/docs/contents/postgres/func.rst +++ b/docs/contents/postgres/func.rst @@ -1,7 +1,7 @@ Examples for using SQL functions ================================ -.. py:currentmodule:: pg +.. currentmodule:: pg We assume that you have already created a connection to the PostgreSQL database, as explained in the :doc:`basic`:: diff --git a/docs/contents/postgres/syscat.rst b/docs/contents/postgres/syscat.rst index 13740203..80718afb 100644 --- a/docs/contents/postgres/syscat.rst +++ b/docs/contents/postgres/syscat.rst @@ -1,7 +1,7 @@ Examples for using the system catalogs ====================================== -.. py:currentmodule:: pg +.. currentmodule:: pg The system catalogs are regular tables where PostgreSQL stores schema metadata, such as information about tables and columns, and internal bookkeeping diff --git a/docs/contents/tutorial.rst b/docs/contents/tutorial.rst index 0ce05430..15577ad3 100644 --- a/docs/contents/tutorial.rst +++ b/docs/contents/tutorial.rst @@ -11,7 +11,7 @@ with both flavors of the PyGreSQL interface. Please choose your flavor: First Steps with the classic PyGreSQL Interface ----------------------------------------------- -.. py:currentmodule:: pg +.. currentmodule:: pg Before doing anything else, it's necessary to create a database connection. @@ -190,7 +190,7 @@ For more advanced features and details, see the reference: :doc:`pg/index` First Steps with the DB-API 2.0 Interface ----------------------------------------- -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb As with the classic interface, the first thing you need to do is to create a database connection. 
To do this, use the function :func:`pgdb.connect` diff --git a/tox.ini b/tox.ini index 680a94f7..bd5f222e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py{27,35,36,37,38},flake8 +envlist = py{27,35,36,37,38},flake8,docs [testenv:flake8] basepython = python3.8 @@ -9,6 +9,14 @@ deps = flake8>=3.8,<4 commands = flake8 setup.py pg.py pgdb.py tests +[testenv:docs] +basepython = python3.8 +deps = + sphinx>=2.4,<3 + cloud_sptheme>=1.10,<2 +commands = + sphinx-build -b html -nEW docs docs/_build/html + [testenv] commands = python setup.py clean --all build_ext --force --inplace --strict From fbde45eecbae95ce129d6cfee2d36f3e970b04ab Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Jun 2020 21:00:32 +0200 Subject: [PATCH 021/194] Fix setUpClass in test_classic --- tests/test_classic.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_classic.py b/tests/test_classic.py index f106e0a4..15db5060 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -39,8 +39,8 @@ def open_db(): class UtilityTest(unittest.TestCase): @classmethod - def setupClass(cls): - """Drop test tables""" + def setUpClass(cls): + """Recreate test tables and schemas""" db = open_db() try: db.query("DROP VIEW _test_vschema") @@ -50,13 +50,17 @@ def setupClass(cls): db.query("DROP TABLE _test_schema") except Exception: pass - db.query("CREATE TABLE _test_schema " - "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") + db.query("CREATE TABLE _test_schema" + " (_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") db.query("CREATE VIEW _test_vschema AS" " SELECT _test, 'abc'::text AS _test2 FROM _test_schema") for t in ('_test1', '_test2'): try: - db.query("DROP SCHEMA %s CASCADE") + db.query("CREATE SCHEMA " + t) + except Exception: + pass + try: + db.query("DROP TABLE %s._test_schema" % (t,)) except Exception: pass db.query("CREATE TABLE %s._test_schema" From 856457dc60d7602c4682f05d1e980826151c2efe Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Jun 2020 21:56:00 +0200 Subject: [PATCH 022/194] Support bump2version Needs bump2version > 1.0 for the duplicate file sections (not yet available). But at least the various places where the version appears are now documented. --- .bumpversion.cfg | 33 +++++++++++++++++++++++++++++++++ docs/conf.py | 4 +--- 2 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 .bumpversion.cfg diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 00000000..31f0ca1f --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,33 @@ +[bumpversion] +current_version = 5.2 +commit = False +tag = False + +parse = (?P\d+)\.(?P\d+)(?:\.(?P\d+))? 
+serialize =
+    {major}.{minor}.{patch}
+    {major}.{minor}
+
+[bumpversion:file:setup.py]
+search = version = '{current_version}'
+replace = version = '{new_version}'
+
+[bumpversion:file (head):setup.py]
+search = PyGreSQL version {current_version}
+replace = PyGreSQL version {new_version}
+
+[bumpversion:file:docs/conf.py]
+search = version = release = '{current_version}'
+replace = version = release = '{new_version}'
+
+[bumpversion:file:docs/about.txt]
+search = PyGreSQL {current_version}
+replace = PyGreSQL {new_version}
+
+[bumpversion:file:docs/announce.rst]
+search = PyGreSQL version {current_version}
+replace = PyGreSQL version {new_version}
+
+[bumpversion:file (text):docs/announce.rst]
+search = Release {current_version} of PyGreSQL
+replace = Release {new_version} of PyGreSQL
diff --git a/docs/conf.py b/docs/conf.py
index cbe51f6d..3eae7b58 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -67,10 +67,8 @@
 # |version| and |release|, also used in various other places throughout the
 # built documents.
 #
-# The short X.Y version.
-version = '5.2'
 # The full version, including alpha/beta/rc tags.
-release = '5.2'
+version = release = '5.2'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.

From cc55191e9f9943e1daf887e52bbd4946387131ff Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Thu, 18 Jun 2020 23:24:56 +0200
Subject: [PATCH 023/194] Allow passing types as string in format_query.

---
 docs/contents/changelog.rst     |  2 ++
 pg.py                           |  9 ++++++++-
 tests/test_classic_dbwrapper.py | 11 +++++++++++
 3 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst
index af038add..942e036e 100644
--- a/docs/contents/changelog.rst
+++ b/docs/contents/changelog.rst
@@ -16,6 +16,8 @@ Version 5.2 (to be released)
   instead of Exception, as required by the DB-API 2 compliance test.
 - Connection arguments containing single quotes caused problems (reported
   and fixed by Tyler Ramer and Jamie McAtamney).
+- The `types` parameer of `format_query` can now be passed as a string
+  that will be split on whitespace when values are passed as a sequence.
 
 Version 5.1.2 (2020-04-19)
 --------------------------
diff --git a/pg.py b/pg.py
index 94cfa653..2a9a75a0 100644
--- a/pg.py
+++ b/pg.py
@@ -686,7 +686,12 @@ def parameter_list(self):
         return params
 
     def format_query(self, command, values=None, types=None, inline=False):
-        """Format a database query using the given values and types."""
+        """Format a database query using the given values and types.
+
+        The optional types describe the values and must be passed as a list,
+        tuple or string (that will be split on whitespace) when values are
+        passed as a list or tuple, or as a dict if values are passed as a dict.
+ """ if not values: return command, [] if inline and types: @@ -699,6 +704,8 @@ def format_query(self, command, values=None, types=None, inline=False): else: add = params.add if types: + if isinstance(types, basestring): + types = types.split() if (not isinstance(types, (list, tuple)) or len(types) != len(values)): raise TypeError('The values and types do not match') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 5e0cd73b..a19c1ee0 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4343,6 +4343,17 @@ def testAdaptQueryTypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + def testAdaptQueryTypedListWithString(self): + format_query = self.adapter.format_query + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) + values = (3, 7.5, 'hello', True) + types = 'int4 float4 text bool' # pass types as list + sql, params = format_query("select %s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4') + self.assertEqual(params, [3, 7.5, 'hello', 't']) + def testAdaptQueryTypedDict(self): format_query = self.adapter.format_query self.assertRaises( From e771d715e30769421ce40f649754721bd9de57a5 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Jun 2020 01:58:10 +0200 Subject: [PATCH 024/194] Allow specifying types as Python classes (#38) --- docs/contents/changelog.rst | 6 +- docs/contents/pg/db_wrapper.rst | 18 +++++ pg.py | 130 +++++++++++++++++++++----------- pgdb.py | 2 + tests/test_classic_dbwrapper.py | 89 +++++++++++++++++++++- 5 files changed, 197 insertions(+), 48 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 942e036e..993b037a 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -16,8 +16,10 @@ Version 5.2 (to be released) instead of Exception, as required by the DB-API 2 compliance test. - Connection arguments containing single quotes caused problems (reported and fixed by Tyler Ramer and Jamie McAtamney). - - The `types` parameer of `format_query` can now be passed as a string - that will be split on whitespace when values are passed as a sequence. + - The `types` parameter of `format_query` can now be passed as a string + that will be split on whitespace when values are passed as a sequence, + and the types can now also be specified using actual Python types + instead of type names (#38, suggested by Justin Pryzby). 
 Version 5.1.2 (2020-04-19)
 --------------------------
diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst
index 9cef63d4..92ef1ad9 100644
--- a/docs/contents/pg/db_wrapper.rst
+++ b/docs/contents/pg/db_wrapper.rst
@@ -514,6 +514,24 @@ Example::
         "update employees set phone=%(phone)s where name=%(name)s",
         dict(name=name, phone=phone)).getresult()[0][0]
 
+Example with specification of types::
+
+    db.query_formatted(
+        "update orders set info=%s where id=%s",
+        ({'customer': 'Joe', 'product': 'beer'}, 7),
+        types=('json', 'int'))
+    # or
+    db.query_formatted(
+        "update orders set info=%s where id=%s",
+        ({'customer': 'Joe', 'product': 'beer'}, 7),
+        types='json int')
+    # or
+    db.query_formatted(
+        "update orders set info=%(info)s where id=%(id)s",
+        {'info': {'customer': 'Joe', 'product': 'beer'}, 'id': 7},
+        types={'info': 'json', 'id': 'int'})
+
+
 query_prepared -- execute a prepared statement
 ----------------------------------------------
diff --git a/pg.py b/pg.py
index 2a9a75a0..e45abc7f 100644
--- a/pg.py
+++ b/pg.py
@@ -87,6 +87,11 @@
 except NameError:  # Python >= 3.0
     long = int
 
+try:  # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
+    unicode
+except NameError:  # Python >= 3.0
+    unicode = str
+
 try:  # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
     basestring
 except NameError:  # Python >= 3.0
@@ -253,10 +258,51 @@ def _oid_key(table):
     return 'oid(%s)' % table
 
 
+class Bytea(bytes):
+    """Wrapper class for marking Bytea values."""
+
+
+class Hstore(dict):
+    """Wrapper class for marking hstore values."""
+
+    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
+
+    @classmethod
+    def _quote(cls, s):
+        if s is None:
+            return 'NULL'
+        if not isinstance(s, basestring):
+            s = str(s)
+        if not s:
+            return '""'
+        s = s.replace('"', '\\"')
+        if cls._re_quote.search(s):
+            s = '"%s"' % s
+        return s
+
+    def __str__(self):
+        q = self._quote
+        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
+
+
+class Json:
+    """Wrapper class for marking Json values."""
+
+    def __init__(self, obj, encode=None):
+        self.obj = obj
+        self.encode = encode or jsonencode
+
+    def __str__(self):
+        obj = self.obj
+        if isinstance(obj, basestring):
+            return obj
+        return self.encode(obj)
+
+
 class _SimpleTypes(dict):
     """Dictionary mapping pg_type names to simple type names."""
 
-    _types = {
+    _type_strings = {
         'bool': 'bool',
         'bytea': 'bytea',
         'date': 'date interval time timetz timestamp timestamptz'
                 ' abstime reltime',  # these are very old
         'float': 'float4 float8',
         'int': 'cid int2 int4 int8 oid xid',
         'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
@@ -267,9 +313,20 @@ class _SimpleTypes(dict):
         'num': 'numeric', 'money': 'money',
         'text': 'bpchar char name text varchar'}
 
+    _type_classes = {
+        bool: 'bool', float: 'float', int: 'int',
+        bytes: 'text' if bytes is str else 'bytea', unicode: 'text',
+        date: 'date', time: 'date', datetime: 'date', timedelta: 'date',
+        Decimal: 'num', Bytea: 'bytea', Json: 'json', Hstore: 'hstore',
+    }
+
+    if long is not int:
+        _type_classes[long] = 'num'
+
+    # noinspection PyMissingConstructor
     def __init__(self):
-        for typ, keys in self._types.items():
+        self.update(self._type_classes)
+        for typ, keys in self._type_strings.items():
             for key in keys.split():
                 self[key] = typ
                 self['_%s' % key] = '%s[]' % typ
@@ -312,38 +369,6 @@ def add(self, value, typ=None):
         return '$%d' % len(self)
 
 
-class Bytea(bytes):
-    """Wrapper class for marking Bytea values."""
-
-
-class Hstore(dict):
-    """Wrapper class for marking hstore values."""
-
-    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
-
-    @classmethod
-    def _quote(cls, s):
-        if s is None:
-            return 'NULL'
-        if not s:
-            return '""'
-        s = 
s.replace('"', '\\"') - if cls._re_quote.search(s): - s = '"%s"' % s - return s - - def __str__(self): - q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) - - -class Json: - """Wrapper class for marking Json values.""" - - def __init__(self, obj): - self.obj = obj - - class Literal(str): """Wrapper class for marking literal SQL values.""" @@ -427,8 +452,22 @@ def _adapt_json(self, v): return None if isinstance(v, basestring): return v + if isinstance(v, Json): + return str(v) return self.db.encode_json(v) + def _adapt_hstore(self, v): + """Adapt a hstore parameter.""" + if not v: + return None + if isinstance(v, basestring): + return v + if isinstance(v, Hstore): + return str(v) + if isinstance(v, dict): + return str(Hstore(v)) + raise TypeError('Hstore parameter %s has wrong type' % v) + @classmethod def _adapt_text_array(cls, v): """Adapt a text type array parameter.""" @@ -588,8 +627,6 @@ def guess_simple_type(cls, value): return cls._frequent_simple_types[type(value)] except KeyError: pass - if isinstance(value, Bytea): - return 'bytea' if isinstance(value, basestring): return 'text' if isinstance(value, bool): @@ -602,6 +639,12 @@ def guess_simple_type(cls, value): return 'num' if isinstance(value, (date, time, datetime, timedelta)): return 'date' + if isinstance(value, Bytea): + return 'bytea' + if isinstance(value, Json): + return 'json' + if isinstance(value, Hstore): + return 'hstore' if isinstance(value, list): return '%s[]' % (cls.guess_simple_base_type(value) or 'text',) if isinstance(value, tuple): @@ -638,12 +681,6 @@ def adapt_inline(self, value, nested=False): value = self.db.escape_bytea(value) if bytes is not str: # Python >= 3.0 value = value.decode('ascii') - elif isinstance(value, Json): - # noinspection PyUnresolvedReferences - if value.encode: - # noinspection PyUnresolvedReferences - return value.encode() - value = self.db.encode_json(value) elif isinstance(value, (datetime, date, time, timedelta)): value = str(value) if isinstance(value, basestring): @@ -666,6 +703,12 @@ def adapt_inline(self, value, nested=False): if isinstance(value, tuple): q = self.adapt_inline return '(%s)' % ','.join(str(q(v)) for v in value) + if isinstance(value, Json): + value = self.db.escape_string(str(value)) + return "'%s'::json" % value + if isinstance(value, Hstore): + value = self.db.escape_string(str(value)) + return "'%s'::hstore" % value pg_repr = getattr(value, '__pg_repr__', None) if not pg_repr: raise InterfaceError( @@ -691,6 +734,9 @@ def format_query(self, command, values=None, types=None, inline=False): The optional types describe the values and must be passed as a list, tuple or string (that will be split on whitespace) when values are passed as a list or tuple, or as a dict if values are passed as a dict. + + If inline is set to True, then parameters will be passed inline + together with the query string. 
""" if not values: return command, [] diff --git a/pgdb.py b/pgdb.py index 9016c3cb..ec94ceb5 100644 --- a/pgdb.py +++ b/pgdb.py @@ -1855,6 +1855,8 @@ class Hstore(dict): def _quote(cls, s): if s is None: return 'NULL' + if not isinstance(s, basestring): + s = str(s) if not s: return '""' quote = cls._re_quote.search(s) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index a19c1ee0..da753b9c 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4343,17 +4343,64 @@ def testAdaptQueryTypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryTypedListWithString(self): + def testAdaptQueryTypedListWithTypesAsString(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), 'int2') self.assertRaises( - TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) + TypeError, format_query, '%s,%s', (1,), 'int2 int2') values = (3, 7.5, 'hello', True) - types = 'int4 float4 text bool' # pass types as list + types = 'int4 float4 text bool' # pass types as string sql, params = format_query("select %s,%s,%s,%s", values, types) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) + def testAdaptQueryTypedListWithTypesAsClasses(self): + format_query = self.adapter.format_query + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), (int,)) + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), (int, int)) + values = (3, 7.5, 'hello', True) + types = (int, float, str, bool) # pass types as classes + sql, params = format_query("select %s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4') + self.assertEqual(params, [3, 7.5, 'hello', 't']) + + def testAdaptQueryTypedListWithJson(self): + format_query = self.adapter.format_query + value = {'test': [1, "it's fine", 3]} + sql, params = format_query("select %s", (value,), 'json') + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,), 'json') + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + value = {'test': [1, "it's fine", 3]} + sql, params = format_query("select %s", [value], [pg.Json]) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + + def testAdaptQueryTypedWithHstore(self): + format_query = self.adapter.format_query + value = {'one': "it's fine", 'two': 2} + sql, params = format_query("select %s", (value,), 'hstore') + self.assertEqual(sql, "select $1") + if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict + params[0] = ','.join(sorted(params[0].split(','))) + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,), 'hstore') + self.assertEqual(sql, "select $1") + if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict + params[0] = ','.join(sorted(params[0].split(','))) + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", [value], [pg.Hstore]) + self.assertEqual(sql, "select $1") + if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict + params[0] 
= ','.join(sorted(params[0].split(','))) + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + def testAdaptQueryTypedDict(self): format_query = self.adapter.format_query self.assertRaises( @@ -4423,6 +4470,22 @@ def testAdaptQueryUntypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + def testAdaptQueryUntypedListWithJson(self): + format_query = self.adapter.format_query + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + + def testAdaptQueryUntypedWithHstore(self): + format_query = self.adapter.format_query + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, "select $1") + if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict + params[0] = ','.join(sorted(params[0].split(','))) + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + def testAdaptQueryUntypedDict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) @@ -4478,6 +4541,24 @@ def testAdaptQueryInlineList(self): sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) + def testAdaptQueryInlineListWithJson(self): + format_query = self.adapter.format_query + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,), inline=True) + self.assertEqual( + sql, "select '{\"test\": [1, \"it''s fine\", 3]}'::json") + self.assertEqual(params, []) + + def testAdaptQueryInlineListWithHstore(self): + format_query = self.adapter.format_query + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,), inline=True) + if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict + sql = sql[:8] + ','.join(sorted(sql[8:-9].split(','))) + sql[-9:] + self.assertEqual( + sql, "select 'one=>\"it''s fine\",two=>2'::hstore") + self.assertEqual(params, []) + def testAdaptQueryInlineDict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) From 9e41b301e0f8201e3dd534127092cea1117d37f9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Jun 2020 02:01:12 +0200 Subject: [PATCH 025/194] Python 2.6 needs not be considered in tests anymore --- tests/test_classic_dbwrapper.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index da753b9c..49edf374 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1010,11 +1010,6 @@ def testQueryFormatted(self): q = f("select %(a)s, %(b)s, %(c)s, %(d)s", dict(a=3, b=2.5, c='hello', d=True), inline=True) r = q.getresult()[0] - if isinstance(r[1], Decimal): - # Python 2.6 cannot compare float and Decimal - r = list(r) - r[1] = float(r[1]) - r = tuple(r) self.assertEqual(r, (3, 2.5, 'hello', t)) # test with dict and extra values q = f("select %(a)s||%(b)s||%(c)s||%(d)s||'epsilon'", @@ -4829,7 +4824,7 @@ def getLeaks(self, fut): objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python issue 26811 + # workaround for Python 3.5 issue 26811 objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) From 78d336af8c9e3913e14d76d2922c36327c052385 Mon Sep 17 00:00:00 
2001 From: Christoph Zwerschke Date: Fri, 19 Jun 2020 13:06:14 +0200 Subject: [PATCH 026/194] Remove redundant type mapping Also, allow List[type] as type alias for arrays. --- pg.py | 89 +++++++++++++++++++++++++++++++---------------------------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/pg.py b/pg.py index e45abc7f..80852e08 100644 --- a/pg.py +++ b/pg.py @@ -82,6 +82,13 @@ from json import loads as jsondecode, dumps as jsonencode from uuid import UUID +try: + # noinspection PyUnresolvedReferences + from typing import Dict, List, Union + has_typing = True +except ImportError: # Python < 3.5 + has_typing = False + try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable long except NameError: # Python >= 3.0 @@ -300,43 +307,54 @@ def __str__(self): class _SimpleTypes(dict): - """Dictionary mapping pg_type names to simple type names.""" - - _type_strings = { - 'bool': 'bool', - 'bytea': 'bytea', - 'date': 'date interval time timetz timestamp timestamptz' - ' abstime reltime', # these are very old - 'float': 'float4 float8', - 'int': 'cid int2 int4 int8 oid xid', - 'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid', - 'num': 'numeric', 'money': 'money', - 'text': 'bpchar char name text varchar'} - - _type_classes = { - bool: 'bool', float: 'float', int: 'int', - bytes: 'text' if bytes is str else 'bytea', unicode: 'text', - date: 'date', time: 'date', datetime: 'date', timedelta: 'date', - Decimal: 'num', Bytea: 'bytea', Json: 'json', Hstore: 'hstore', - } - - if long is not int: - _type_classes[long] = 'num' + """Dictionary mapping pg_type names to simple type names. + + The corresponding Python types and simple names are also mapped. + """ + + _type_aliases = { + 'bool': [bool], + 'bytea': [Bytea], + 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', + 'abstime', 'reltime', # these are very old + 'datetime', 'timedelta', # these do not really exist + date, time, datetime, timedelta], + 'float': ['float4', 'float8', float], + 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], + 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], + 'num': ['numeric', Decimal], 'money': [], + 'text': ['bpchar', 'char', 'name', 'varchar', + bytes, unicode, basestring] + } # type: Dict[str, List[Union[str, type]]] + + if long is not int: # Python 2 has a separate long type + _type_aliases['num'].append(long) # noinspection PyMissingConstructor def __init__(self): - self.update(self._type_classes) - for typ, keys in self._type_strings.items(): - for key in keys.split(): + """Initialize type mapping.""" + for typ, keys in self._type_aliases.items(): + keys = [typ] + keys + for key in keys: self[key] = typ - self['_%s' % key] = '%s[]' % typ + if isinstance(key, str): + self['_%s' % key] = '%s[]' % typ + elif has_typing and not isinstance(key, tuple): + self[List[key]] = '%s[]' % typ @staticmethod def __missing__(key): + """Unmapped types are interpreted as text.""" return 'text' + def get_type_dict(self): + """Get a plain dictionary of only the types.""" + return dict((key, typ) for key, typ in self.items() + if not isinstance(key, (str, tuple))) + _simpletypes = _SimpleTypes() +_simple_type_dict = _simpletypes.get_type_dict() def _quote_if_unqualified(param, name): @@ -604,27 +622,12 @@ def get_attnames(typ): return typ.attnames return {} - _frequent_simple_types = { - Bytea: 'bytea', - str: 'text', - bytes: 'text', - bool: 'bool', - int: 'int', - long: 'int', - float: 'float', - Decimal: 'num', - date: 'date', - time: 'date', - datetime: 
'date', - timedelta: 'date' - } - @classmethod def guess_simple_type(cls, value): """Try to guess which database type the given value has.""" # optimize for most frequent types try: - return cls._frequent_simple_types[type(value)] + return _simple_type_dict[type(value)] except KeyError: pass if isinstance(value, basestring): @@ -645,6 +648,8 @@ def guess_simple_type(cls, value): return 'json' if isinstance(value, Hstore): return 'hstore' + if isinstance(value, UUID): + return 'uuid' if isinstance(value, list): return '%s[]' % (cls.guess_simple_base_type(value) or 'text',) if isinstance(value, tuple): From 5b57de33f5218f26ce7cef21fa7846bb61a6db13 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Jun 2020 22:58:46 +0200 Subject: [PATCH 027/194] Support asynchronous command processing (#19) --- docs/contents/changelog.rst | 6 +- docs/contents/pg/connection.rst | 180 +++++++++++++++++ docs/contents/pg/module.rst | 15 +- docs/contents/pg/query.rst | 9 + pg.py | 1 + pgconn.c | 268 +++++++++++++++++++------ pgmodule.c | 68 +++++-- pgquery.c | 330 ++++++++++++++++++++++--------- tests/test_classic_connection.py | 105 +++++++++- tests/test_classic_dbwrapper.py | 8 +- 10 files changed, 813 insertions(+), 177 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 993b037a..856bd995 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -10,6 +10,10 @@ Version 5.2 (to be released) of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). - New query method `memsize()` that gets the memory size allocated by the query (needs PostgreSQL >= 12 on the client). + - Experimental support for asynchronous command processing. + Additional connection parameter ``nowait``, and connection methods + `send_query()`, `poll()`, `set_non_blocking()`, `is_non_blocking()`. + Generously contributed by Patrick TJ McPhee (#19). - Changes to the DB-API 2 module (pgdb): - When using Python 2, errors are now derived from StandardError @@ -19,7 +23,7 @@ Version 5.2 (to be released) - The `types` parameter of `format_query` can now be passed as a string that will be split on whitespace when values are passed as a sequence, and the types can now also be specified using actual Python types - instead of type names (#38, suggested by Justin Pryzby). + instead of type names. Suggested by Justin Pryzby (#38). Version 5.1.2 (2020-04-19) -------------------------- diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 56a8fe4e..f342c028 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -81,6 +81,102 @@ Example:: phone = con.query("select phone from employees where name=$1", (name,)).getresult() + +send_query - executes a SQL command string asynchronously +--------------------------------------------------------- + +.. method:: Connection.send_query(command, [args]) + + Submits a command to the server without waiting for the result(s). + + :param str command: SQL command + :param args: optional parameter values + :returns: a query object, as described below + :rtype: :class:`Query` + :raises TypeError: bad argument type, or too many arguments + :raises TypeError: invalid connection + :raises ValueError: empty SQL query or lost connection + :raises pg.ProgrammingError: error in query + +This method is much the same as :meth:`Connection.query`, except that it +returns without waiting for the query to complete. 
The database connection
+cannot be used for other operations until the query completes, but the
+application can do other things, including executing queries using other
+database connections. The application can call ``select()`` using the
+``fileno`` obtained by the connection's :meth:`Connection.fileno` method
+to determine when the query has results to return.
+
+This method always returns a :class:`Query` object. This object differs
+from the :class:`Query` object returned by :meth:`Connection.query` in a
+few ways. Most importantly, when :meth:`Connection.send_query` is used, the
+application must call one of the result-returning methods such as
+:meth:`Query.getresult` or :meth:`Query.dictresult` until it either raises
+an exception or returns ``None``.
+
+Otherwise, the database connection will be left in an unusable state.
+
+In cases when :meth:`Connection.query` would return something other than
+a :class:`Query` object, that result will be returned by calling one of
+the result-returning methods on the :class:`Query` object returned by
+:meth:`Connection.send_query`. There's one important difference in these
+result codes: if :meth:`Connection.query` returns `None`, the result-returning
+methods will return an empty string (`''`). It's still necessary to call a
+result-returning method until it returns `None`.
+
+:meth:`Query.listfields`, :meth:`Query.fieldname`, :meth:`Query.fieldnum`,
+and :meth:`Query.ntuples` only work after a call to a result-returning method
+with a non-`None` return value. :meth:`Query.ntuples` returns only the number
+of rows returned by the previous result-returning method.
+
+If multiple semi-colon-delimited statements are passed to
+:meth:`Connection.query`, only the results of the last statement are returned
+in the :class:`Query` object. With :meth:`Connection.send_query`, all results
+are returned. Each result set will be returned by a separate call to
+:meth:`Query.getresult()` or other result-returning methods.
+
+.. versionadded:: 5.2
+
+Examples::
+
+    name = input("Name? ")
+    query = con.send_query("select phone from employees where name=$1",
+                           (name,))
+    phone = query.getresult()
+    query.getresult()  # to close the query
+
+    # Run two queries in one round trip:
+    # (Note that you cannot use a union here
+    # when the result sets have different row types.)
+    query = con.send_query("select a,b,c from x where d=e;"
+                           "select e,f from y where g")
+    result_x = query.dictresult()
+    result_y = query.dictresult()
+    query.dictresult()  # to close the query
+
+    # Using select() to wait for the query to be ready:
+    query = con.send_query("select pg_sleep(20)")
+    r, w, e = select([con.fileno(), other, sockets], [], [])
+    if con.fileno() in r:
+        results = query.getresult()
+        query.getresult()  # to close the query
+
+    # Concurrent queries on separate connections:
+    con1 = connect()
+    con2 = connect()
+    s = con1.query("begin; set transaction isolation level repeatable read;"
+                   "select pg_export_snapshot();").getresult()[0][0]
+    con2.query("begin; set transaction isolation level repeatable read;"
+               "set transaction snapshot '%s'" % (s,))
+    q1 = con1.send_query("select a,b,c from x where d=e")
+    q2 = con2.send_query("select e,f from y where g")
+    r1 = q1.getresult()
+    q1.getresult()
+    r2 = q2.getresult()
+    q2.getresult()
+    con1.query("commit")
+    con2.query("commit")
+
+
 query_prepared -- execute a prepared statement
 ----------------------------------------------
@@ -169,6 +265,56 @@ reset -- reset the connection
 
 This method resets the current database connection.
 
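
A minimal sketch of the drain loop required by the send_query documentation
above (illustrative only; ``con`` is assumed to be an open connection and
``t`` an existing table)::

    query = con.send_query("select 1; update t set n=0; select 2")
    while True:
        res = query.getresult()
        if res is None:
            break  # all result sets consumed, connection usable again
        # res is a list of rows for a select, a row count string for
        # other statements, or '' where query() would have returned None
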
+poll - completes an asynchronous connection
+-------------------------------------------
+
+.. method:: Connection.poll()
+
+    Complete an asynchronous :mod:`pg` connection and get its state
+
+    :returns: state of the connection
+    :rtype: int
+    :raises TypeError: too many (any) arguments
+    :raises TypeError: invalid connection
+    :raises pg.InternalError: some error occurred during pg connection
+
+The database connection can be established without any blocking calls.
+This allows the application mainline to perform other operations or perhaps
+connect to multiple databases concurrently. Once the connection is established,
+it's no different from a connection made using blocking calls.
+
+The required steps are to pass the parameter ``nowait=True`` to the
+:meth:`pg.connect` call, then call :meth:`Connection.poll` until it either
+returns :const:`POLLING_OK` or raises an exception. To avoid blocking
+in :meth:`Connection.poll`, use `select()` or `poll()` to wait for the
+connection to be readable or writable, depending on the return code of the
+previous call to :meth:`Connection.poll`. The initial state of the connection
+is :const:`POLLING_WRITING`. The possible states are defined as constants in
+the :mod:`pg` module (:const:`POLLING_OK`, :const:`POLLING_FAILED`,
+:const:`POLLING_READING` and :const:`POLLING_WRITING`).
+
+.. versionadded:: 5.2
+
+Example::
+
+    con = pg.connect('testdb', nowait=True)
+    fileno = con.fileno()
+    rd = []
+    wt = [fileno]
+    rc = pg.POLLING_WRITING
+    while rc not in (pg.POLLING_OK, pg.POLLING_FAILED):
+        ra, wa, xa = select(rd, wt, [], timeout)
+        if not ra and not wa:
+            timedout()
+        rc = con.poll()
+        if rc == pg.POLLING_READING:
+            rd = [fileno]
+            wt = []
+        else:
+            rd = []
+            wt = [fileno]
+
+
 cancel -- abandon processing of current SQL command
 ---------------------------------------------------
@@ -281,6 +427,40 @@ fileno -- get the socket used to connect to the database
 
 This method returns the underlying socket id used to connect
 to the database. This is useful for use in select calls, etc.
 
+set_non_blocking - set the non-blocking status of the connection
+----------------------------------------------------------------
+
+.. method:: set_non_blocking(nb)
+
+    Set the non-blocking mode of the connection
+
+    :param bool nb: True to put the connection into non-blocking mode.
+                    False to put it into blocking mode.
+    :raises TypeError: too many parameters
+    :raises TypeError: invalid connection
+
+Puts the socket connection into non-blocking mode or into blocking mode.
+This affects copy commands and large object operations, but not queries.
+
+.. versionadded:: 5.2
+
+is_non_blocking - report the blocking status of the connection
+--------------------------------------------------------------
+
+.. method:: is_non_blocking()
+
+    Get the non-blocking mode of the connection
+
+    :returns: True if the connection is in non-blocking mode.
+              False if it is in blocking mode.
+    :rtype: bool
+    :raises TypeError: too many parameters
+    :raises TypeError: invalid connection
+
+Returns True if the connection is in non-blocking mode, False otherwise.
+
+.. versionadded:: 5.2
+
 getnotify -- get the last notify from the server
 ------------------------------------------------
diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst
index adb04845..b122808b 100644
--- a/docs/contents/pg/module.rst
+++ b/docs/contents/pg/module.rst
@@ -20,7 +20,7 @@ standard environment variables should be used.
connect -- Open a PostgreSQL connection --------------------------------------- -.. function:: connect([dbname], [host], [port], [opt], [user], [passwd]) +.. function:: connect([dbname], [host], [port], [opt], [user], [passwd], [nowait]) Open a :mod:`pg` connection @@ -36,6 +36,8 @@ connect -- Open a PostgreSQL connection :type user: str or None :param passwd: password for user (*None* = :data:`defpasswd`) :type passwd: str or None + :param nowait: whether the connection should happen asynchronously + :type nowait: bool :returns: If successful, the :class:`Connection` handling the connection :rtype: :class:`Connection` :raises TypeError: bad argument type, or too many arguments @@ -49,11 +51,15 @@ Python tutorial. The names of the keywords are the name of the parameters given in the syntax line. The ``opt`` parameter can be used to pass command-line options to the server. For a precise description of the parameters, please refer to the PostgreSQL user manual. +See :meth:`Connection.poll` for a description of the ``nowait`` parameter. If you want to add additional parameters not specified here, you must pass a connection string or a connection URI instead of the ``dbname`` (as in ``con3`` and ``con4`` in the following example). +.. versionchanged:: 5.2 + Support for asynchronous connections via the ``nowait`` parameter. + Example:: import pg @@ -747,6 +753,13 @@ for more information about them. These constants are: large objects access modes, used by :meth:`Connection.locreate` and :meth:`LargeObject.open` +.. data:: POLLING_OK +.. data:: POLLING_FAILED +.. data:: POLLING_READING +.. data:: POLLING_WRITING + + polling states, returned by :meth:`Connection.poll` + .. data:: SEEK_SET .. data:: SEEK_CUR .. data:: SEEK_END diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index be764c74..4fcf46fb 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -43,6 +43,9 @@ You can also call :func:`len` on a query to find the number of rows in the result, and access row tuples using their index directly on the :class:`Query` object. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + dictresult/dictiter -- get query values as dictionaries ------------------------------------------------------- @@ -81,6 +84,9 @@ fetched from the server anyway when the query is executed. If the query has duplicate field names, you will get the value for the field with the highest index in the query. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + .. versionadded:: 5.1 namedresult/namediter -- get query values as named tuples @@ -127,6 +133,9 @@ Column names in the database that are not valid as field names for named tuples (particularly, names starting with an underscore) are automatically renamed to valid positional names. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + .. 
versionadded:: 5.1 scalarresult/scalariter -- get query values as scalars diff --git a/pg.py b/pg.py index 80852e08..cf996cf0 100644 --- a/pg.py +++ b/pg.py @@ -52,6 +52,7 @@ 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', 'INV_READ', 'INV_WRITE', + 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', 'TRANS_INTRANS', 'TRANS_UNKNOWN', diff --git a/pgconn.c b/pgconn.c index e16fd68e..f8c81e28 100644 --- a/pgconn.c +++ b/pgconn.c @@ -160,9 +160,57 @@ conn_source(connObject *self, PyObject *noargs) return (PyObject *) source_obj; } -/* Base method for execution of both unprepared and prepared queries */ +/* For a non-query result, set the appropriate error status, + return the appropriate value, and free the result set. */ static PyObject * -_conn_query(connObject *self, PyObject *args, int prepared) +_conn_non_query_result(int status, PGresult* result, PGconn *cnx) +{ + switch (status) { + case PGRES_EMPTY_QUERY: + PyErr_SetString(PyExc_ValueError, "Empty query"); + break; + case PGRES_BAD_RESPONSE: + case PGRES_FATAL_ERROR: + case PGRES_NONFATAL_ERROR: + set_error(ProgrammingError, "Cannot execute query", + cnx, result); + break; + case PGRES_COMMAND_OK: + { /* INSERT, UPDATE, DELETE */ + Oid oid = PQoidValue(result); + + if (oid == InvalidOid) { /* not a single insert */ + char *ret = PQcmdTuples(result); + + if (ret[0]) { /* return number of rows affected */ + PyObject *obj = PyStr_FromString(ret); + PQclear(result); + return obj; + } + PQclear(result); + Py_INCREF(Py_None); + return Py_None; + } + /* for a single insert, return the oid */ + PQclear(result); + return PyInt_FromLong(oid); + } + case PGRES_COPY_OUT: /* no data will be received */ + case PGRES_COPY_IN: + PQclear(result); + Py_INCREF(Py_None); + return Py_None; + default: + set_error_msg(InternalError, "Unknown result status"); + } + + PQclear(result); + return NULL; /* error detected on query */ + } + +/* Base method for execution of all different kinds of queries */ +static PyObject * +_conn_query(connObject *self, PyObject *args, int prepared, int async) { PyObject *query_str_obj, *param_obj = NULL; PGresult* result; @@ -282,11 +330,19 @@ _conn_query(connObject *self, PyObject *args, int prepared) } Py_BEGIN_ALLOW_THREADS - result = prepared ? - PQexecPrepared(self->cnx, query, nparms, - parms, NULL, NULL, 0) : - PQexecParams(self->cnx, query, nparms, - NULL, parms, NULL, NULL, 0); + if (async) { + status = PQsendQueryParams(self->cnx, query, nparms, + NULL, (const char * const *)parms, NULL, NULL, 0); + result = NULL; + } + else { + result = prepared ? + PQexecPrepared(self->cnx, query, nparms, + parms, NULL, NULL, 0) : + PQexecParams(self->cnx, query, nparms, + NULL, parms, NULL, NULL, 0); + status = result != NULL; + } Py_END_ALLOW_THREADS PyMem_Free((void *) parms); @@ -295,10 +351,17 @@ _conn_query(connObject *self, PyObject *args, int prepared) } else { Py_BEGIN_ALLOW_THREADS - result = prepared ? - PQexecPrepared(self->cnx, query, 0, - NULL, NULL, NULL, 0) : - PQexec(self->cnx, query); + if (async) { + status = PQsendQuery(self->cnx, query); + result = NULL; + } + else { + result = prepared ? 
+ PQexecPrepared(self->cnx, query, 0, + NULL, NULL, NULL, 0) : + PQexec(self->cnx, query); + status = result != NULL; + } Py_END_ALLOW_THREADS } @@ -307,7 +370,7 @@ _conn_query(connObject *self, PyObject *args, int prepared) Py_XDECREF(param_obj); /* checks result validity */ - if (!result) { + if (!status) { PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; } @@ -317,49 +380,8 @@ _conn_query(connObject *self, PyObject *args, int prepared) self->date_format = date_format; /* this is normally NULL */ /* checks result status */ - if ((status = PQresultStatus(result)) != PGRES_TUPLES_OK) { - switch (status) { - case PGRES_EMPTY_QUERY: - PyErr_SetString(PyExc_ValueError, "Empty query"); - break; - case PGRES_BAD_RESPONSE: - case PGRES_FATAL_ERROR: - case PGRES_NONFATAL_ERROR: - set_error(ProgrammingError, "Cannot execute query", - self->cnx, result); - break; - case PGRES_COMMAND_OK: - { /* INSERT, UPDATE, DELETE */ - Oid oid = PQoidValue(result); - - if (oid == InvalidOid) { /* not a single insert */ - char *ret = PQcmdTuples(result); - - if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyStr_FromString(ret); - PQclear(result); - return obj; - } - PQclear(result); - Py_INCREF(Py_None); - return Py_None; - } - /* for a single insert, return the oid */ - PQclear(result); - return PyInt_FromLong(oid); - } - case PGRES_COPY_OUT: /* no data will be received */ - case PGRES_COPY_IN: - PQclear(result); - Py_INCREF(Py_None); - return Py_None; - default: - set_error_msg(InternalError, "Unknown result status"); - } - - PQclear(result); - return NULL; /* error detected on query */ - } + if (result && (status = PQresultStatus(result)) != PGRES_TUPLES_OK) + return _conn_non_query_result(status, result, self->cnx); if (!(query_obj = PyObject_New(queryObject, &queryType))) return PyErr_NoMemory(); @@ -368,15 +390,23 @@ _conn_query(connObject *self, PyObject *args, int prepared) Py_XINCREF(self); query_obj->pgcnx = self; query_obj->result = result; + query_obj->async = async; query_obj->encoding = encoding; query_obj->current_row = 0; - query_obj->max_row = PQntuples(result); - query_obj->num_fields = PQnfields(result); - query_obj->col_types = get_col_types(result, query_obj->num_fields); - if (!query_obj->col_types) { - Py_DECREF(query_obj); - Py_DECREF(self); - return NULL; + if (async) { + query_obj->max_row = 0; + query_obj->num_fields = 0; + query_obj->col_types = NULL; + } + else { + query_obj->max_row = PQntuples(result); + query_obj->num_fields = PQnfields(result); + query_obj->col_types = get_col_types(result, query_obj->num_fields); + if (!query_obj->col_types) { + Py_DECREF(query_obj); + Py_DECREF(self); + return NULL; + } } return (PyObject *) query_obj; @@ -391,7 +421,19 @@ static char conn_query__doc__[] = static PyObject * conn_query(connObject *self, PyObject *args) { - return _conn_query(self, args, 0); + return _conn_query(self, args, 0, 0); +} + +/* Asynchronous database query */ +static char conn_send_query__doc__[] = +"send_query(sql, [arg]) -- create a new asynchronous query for this connection\n\n" +"You must pass the SQL (string) request and you can optionally pass\n" +"a tuple with positional parameters.\n"; + +static PyObject * +conn_send_query(connObject *self, PyObject *args) +{ + return _conn_query(self, args, 0, 1); } /* Execute prepared statement. 
*/ @@ -403,7 +445,7 @@ static char conn_query_prepared__doc__[] = static PyObject * conn_query_prepared(connObject *self, PyObject *args) { - return _conn_query(self, args, 1); + return _conn_query(self, args, 1, 0); } /* Create prepared statement. */ @@ -583,6 +625,62 @@ conn_endcopy(connObject *self, PyObject *noargs) Py_INCREF(Py_None); return Py_None; } + +/* Direct access function: set blocking status. */ +static char conn_set_non_blocking__doc__[] = +"set_non_blocking() -- set the non-blocking status of the connection"; + +static PyObject * +conn_set_non_blocking(connObject *self, PyObject *args) +{ + int non_blocking; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + if (!PyArg_ParseTuple(args, "i", &non_blocking)) { + PyErr_SetString(PyExc_TypeError, "setnonblocking(tf), with boolean."); + return NULL; + } + + if (PQsetnonblocking(self->cnx, non_blocking) < 0) { + PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + +/* Direct access function: get blocking status. */ +static char conn_is_non_blocking__doc__[] = +"is_non_blocking() -- report the blocking status of the connection"; + +static PyObject * +conn_is_non_blocking(connObject *self, PyObject *args) +{ + int rc; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + if (!PyArg_ParseTuple(args, "")) { + PyErr_SetString(PyExc_TypeError, + "method is_non_blocking() takes no parameters"); + return NULL; + } + + rc = PQisnonblocking(self->cnx); + if (rc < 0) { + PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + return NULL; + } + + return PyBool_FromLong(rc); +} #endif /* DIRECT_ACCESS */ @@ -1269,6 +1367,40 @@ conn_get_cast_hook(connObject *self, PyObject *noargs) return ret; } +/* Get asynchronous connection state. */ +static char conn_poll__doc__[] = +"poll() -- Completes an asynchronous connection"; + +static PyObject * +conn_poll(connObject *self, PyObject *args) +{ + int rc; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* check args */ + if (!PyArg_ParseTuple(args, "")) { + PyErr_SetString(PyExc_TypeError, + "method poll() takes no parameters"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + rc = PQconnectPoll(self->cnx); + Py_END_ALLOW_THREADS + + if (rc == PGRES_POLLING_FAILED) { + set_error(InternalError, "Polling failed", self->cnx, NULL); + Py_XDECREF(self); + return NULL; + } + + return PyInt_FromLong(rc); +} + /* Set notice receiver callback function. 
*/
 static char conn_set_notice_receiver__doc__[] =
 "set_notice_receiver(func) -- set the current notice receiver";
 
@@ -1417,12 +1549,16 @@ static struct PyMethodDef conn_methods[] = {
         METH_NOARGS, conn_source__doc__},
     {"query", (PyCFunction) conn_query,
         METH_VARARGS, conn_query__doc__},
+    {"send_query", (PyCFunction) conn_send_query,
+        METH_VARARGS, conn_send_query__doc__},
     {"query_prepared", (PyCFunction) conn_query_prepared,
         METH_VARARGS, conn_query_prepared__doc__},
     {"prepare", (PyCFunction) conn_prepare,
         METH_VARARGS, conn_prepare__doc__},
     {"describe_prepared", (PyCFunction) conn_describe_prepared,
         METH_VARARGS, conn_describe_prepared__doc__},
+    {"poll", (PyCFunction) conn_poll,
+        METH_VARARGS, conn_poll__doc__},
     {"reset", (PyCFunction) conn_reset,
         METH_NOARGS, conn_reset__doc__},
     {"cancel", (PyCFunction) conn_cancel,
@@ -1468,6 +1604,10 @@ static struct PyMethodDef conn_methods[] = {
         METH_NOARGS, conn_getline__doc__},
     {"endcopy", (PyCFunction) conn_endcopy,
         METH_NOARGS, conn_endcopy__doc__},
+    {"set_non_blocking", (PyCFunction) conn_set_non_blocking,
+        METH_O, conn_set_non_blocking__doc__},
+    {"is_non_blocking", (PyCFunction) conn_is_non_blocking,
+        METH_NOARGS, conn_is_non_blocking__doc__},
 #endif /* DIRECT_ACCESS */
 
 #ifdef LARGE_OBJECTS
diff --git a/pgmodule.c b/pgmodule.c
index a3d68c54..80c3e043 100644
--- a/pgmodule.c
+++ b/pgmodule.c
@@ -72,7 +72,7 @@ static PyObject *pg_default_passwd;  /* default password */
 #endif  /* DEFAULT_VARS */
 
 static PyObject *decimal = NULL, /* decimal type */
-    *dictiter = NULL, /* function for getting named results */
+    *dictiter = NULL, /* function for getting dict results */
     *namediter = NULL, /* function for getting named results */
     *namednext = NULL, /* function for getting one named result */
     *scalariter = NULL, /* function for getting scalar results */
@@ -154,6 +154,7 @@ typedef struct
     PyObject_HEAD
     connObject *pgcnx;      /* parent connection object */
     PGresult *result;       /* result content */
+    int async;              /* flag for asynchronous queries */
     int encoding;           /* client encoding */
     int current_row;        /* currently selected row */
     int max_row;            /* number of rows in the result */
@@ -197,7 +198,7 @@ typedef struct
 
 /* Connect to a database. */
 static char pg_connect__doc__[] =
-"connect(dbname, host, port, opt) -- connect to a PostgreSQL database\n\n"
+"connect(dbname, host, port, opt, user, passwd, nowait) -- connect to a PostgreSQL database\n\n"
 "The connection uses the specified parameters (optional, keywords aware).\n";
 
 static PyObject *
 pg_connect(PyObject *self, PyObject *args, PyObject *dict)
 {
     static const char *kwlist[] =
     {
-        "dbname", "host", "port", "opt", "user", "passwd", NULL
+        "dbname", "host", "port", "opt", "user", "passwd", "nowait", NULL
     };
 
     char *pghost, *pgopt, *pgdbname, *pguser, *pgpasswd;
-    int pgport;
+    int pgport = -1, nowait = 0, nkw = 0;
     char port_buffer[20];
+    const char *keywords[sizeof(kwlist) / sizeof(*kwlist) + 1],
+               *values[sizeof(kwlist) / sizeof(*kwlist) + 1];
     connObject *conn_obj;
 
     pghost = pgopt = pgdbname = pguser = pgpasswd = NULL;
-    pgport = -1;
 
     /*
      * parses standard arguments With the right compiler warnings, this
      * I try to assign all those constant strings to it. 
*/ if (!PyArg_ParseTupleAndKeywords( - args, dict, "|zzizzz", (char**)kwlist, - &pgdbname, &pghost, &pgport, &pgopt, &pguser, &pgpasswd)) + args, dict, "|zzizzzi", (char**)kwlist, + &pgdbname, &pghost, &pgport, &pgopt, &pguser, &pgpasswd, &nowait)) { return NULL; } @@ -261,14 +263,44 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) conn_obj->cast_hook = NULL; conn_obj->notice_receiver = NULL; - if (pgport != -1) { + if (pghost) + { + keywords[nkw] = "host"; + values[nkw++] = pghost; + } + if (pgopt) + { + keywords[nkw] = "options"; + values[nkw++] = pgopt; + } + if (pgdbname) + { + keywords[nkw] = "dbname"; + values[nkw++] = pgdbname; + } + if (pguser) + { + keywords[nkw] = "user"; + values[nkw++] = pguser; + } + if (pgpasswd) + { + keywords[nkw] = "password"; + values[nkw++] = pgpasswd; + } + if (pgport != -1) + { memset(port_buffer, 0, sizeof(port_buffer)); sprintf(port_buffer, "%d", pgport); + + keywords[nkw] = "port"; + values[nkw++] = port_buffer; } + keywords[nkw] = values[nkw] = NULL; Py_BEGIN_ALLOW_THREADS - conn_obj->cnx = PQsetdbLogin(pghost, pgport == -1 ? NULL : port_buffer, - pgopt, NULL, pgdbname, pguser, pgpasswd); + conn_obj->cnx = nowait ? PQconnectStartParams(keywords, values, 1) : + PQconnectdbParams(keywords, values, 1); Py_END_ALLOW_THREADS if (PQstatus(conn_obj->cnx) == CONNECTION_BAD) { @@ -1321,11 +1353,17 @@ MODULE_INIT_FUNC(_pg) PyDict_SetItemString(dict, "RESULT_DQL", PyInt_FromLong(RESULT_DQL)); /* Transaction states */ - PyDict_SetItemString(dict,"TRANS_IDLE",PyInt_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict,"TRANS_ACTIVE",PyInt_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict,"TRANS_INTRANS",PyInt_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict,"TRANS_INERROR",PyInt_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict,"TRANS_UNKNOWN",PyInt_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_IDLE", PyInt_FromLong(PQTRANS_IDLE)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", PyInt_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", PyInt_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", PyInt_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyInt_FromLong(PQTRANS_UNKNOWN)); + + /* Polling results */ + PyDict_SetItemString(dict, "POLLING_OK", PyInt_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", PyInt_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", PyInt_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", PyInt_FromLong(PGRES_POLLING_WRITING)); #ifdef LARGE_OBJECTS /* Create mode for large objects */ diff --git a/pgquery.c b/pgquery.c index 54b1466f..fefd8e70 100644 --- a/pgquery.c +++ b/pgquery.c @@ -102,6 +102,89 @@ _query_row_as_tuple(queryObject *self) return row_tuple; } +/* Fetch the result if this is an asynchronous query and it has not yet + been fetched in this round-trip. Also mark whether the result should + be kept for this round-trip (e.g. to be used in an iterator). + If this is a normal query result, the query itself will be returned, + otherwise a result value will be returned that shall be passed on. 
*/
+static PyObject *
+_get_async_result(queryObject *self, int keep) {
+    int fetch = 0;
+
+    if (self->async) {
+        if (self->async == 1) {
+            fetch = 1;
+            if (keep) {
+                /* mark query as fetched, do not fetch again */
+                self->async = 2;
+            }
+        } else if (!keep) {
+            self->async = 1;
+        }
+    }
+
+    if (fetch) {
+        int status;
+
+        if (!self->pgcnx) {
+            PyErr_SetString(PyExc_TypeError, "Connection is not valid");
+            return NULL;
+        }
+
+        Py_BEGIN_ALLOW_THREADS
+        if (self->result) {
+            PQclear(self->result);
+        }
+        self->result = PQgetResult(self->pgcnx->cnx);
+        Py_END_ALLOW_THREADS
+        if (!self->result) {
+            /* end of result set, return None */
+            Py_DECREF(self->pgcnx);
+            self->pgcnx = NULL;
+            Py_INCREF(Py_None);
+            return Py_None;
+        }
+
+        if ((status = PQresultStatus(self->result)) != PGRES_TUPLES_OK) {
+            PyObject* result = _conn_non_query_result(
+                status, self->result, self->pgcnx->cnx);
+            self->result = NULL; /* since this has been already cleared */
+            if (!result) {
+                /* Raise an error. We need to call PQgetResult() to clear the
+                   connection state. This should return NULL the first time. */
+                self->result = PQgetResult(self->pgcnx->cnx);
+                while (self->result) {
+                    PQclear(self->result);
+                    self->result = PQgetResult(self->pgcnx->cnx);
+                }
+                Py_DECREF(self->pgcnx);
+                self->pgcnx = NULL;
+            }
+            else if (result == Py_None) {
+                /* It would be confusing to return None here because the
+                   caller has to call again until we return None. We can't
+                   just consume that final None because we don't know if there
+                   are additional statements following this one, so we return
+                   an empty string where query() would return None. */
+                Py_DECREF(result);
+                result = PyStr_FromString("");
+            }
+            return result;
+        }
+
+        self->max_row = PQntuples(self->result);
+        self->num_fields = PQnfields(self->result);
+        self->col_types = get_col_types(self->result, self->num_fields);
+        if (!self->col_types) {
+            Py_DECREF(self);
+            Py_DECREF(self);
+            return NULL;
+        }
+    }
+
+    /* return the query object itself as sentinel for a normal query result */
+    return (PyObject *)self;
+}
+
 /* Return given item from a query object. */
 static PyObject *
 query_getitem(PyObject *self, Py_ssize_t i)
@@ -110,6 +193,9 @@ query_getitem(PyObject *self, Py_ssize_t i)
     PyObject *tmp;
     long row;
 
+    if ((tmp = _get_async_result(q, 0)) != (PyObject *)self)
+        return tmp;
+
     tmp = PyLong_FromSize_t((size_t) i);
     row = PyLong_AsLong(tmp);
     Py_DECREF(tmp);
@@ -127,6 +213,11 @@ query_getitem(PyObject *self, Py_ssize_t i)
 
    Returns the default iterator yielding rows as tuples. 
*/ static PyObject* query_iter(queryObject *self) { + PyObject *res; + + if ((res = _get_async_result(self, 0)) != (PyObject *)self) + return res; + self->current_row = 0; Py_INCREF(self); return (PyObject*) self; @@ -261,12 +352,16 @@ query_one(queryObject *self, PyObject *noargs) { PyObject *row_tuple; - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { + + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); return Py_None; + } + + row_tuple = _query_row_as_tuple(self); + if (row_tuple) ++self->current_row; } - row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; return row_tuple; } @@ -283,17 +378,21 @@ query_single(queryObject *self, PyObject *noargs) { PyObject *row_tuple; - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; + if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { + + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + row_tuple = _query_row_as_tuple(self); + if (row_tuple) ++self->current_row; } - self->current_row = 0; - row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; return row_tuple; } @@ -309,17 +408,20 @@ query_getresult(queryObject *self, PyObject *noargs) PyObject *result_list; int i; - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - for (i = self->current_row = 0; i < self->max_row; ++i) { - PyObject *row_tuple = query_next(self, noargs); + if (!(result_list = PyList_New(self->max_row))) { + return NULL; + } + + for (i = self->current_row = 0; i < self->max_row; ++i) { + PyObject *row_tuple = query_next(self, noargs); - if (!row_tuple) { - Py_DECREF(result_list); return NULL; + if (!row_tuple) { + Py_DECREF(result_list); return NULL; + } + PyList_SET_ITEM(result_list, i, row_tuple); } - PyList_SET_ITEM(result_list, i, row_tuple); } return result_list; @@ -378,12 +480,16 @@ query_onedict(queryObject *self, PyObject *noargs) { PyObject *row_dict; - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { + + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); return Py_None; + } + + row_dict = _query_row_as_dict(self); + if (row_dict) ++self->current_row; } - row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; return row_dict; } @@ -401,17 +507,21 @@ query_singledict(queryObject *self, PyObject *noargs) { PyObject *row_dict; - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; + if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { + + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + row_dict = _query_row_as_dict(self); + if (row_dict) ++self->current_row; } - self->current_row = 0; - row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; return 
row_dict; } @@ -427,17 +537,20 @@ query_dictresult(queryObject *self, PyObject *noargs) PyObject *result_list; int i; - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { + + if (!(result_list = PyList_New(self->max_row))) { + return NULL; + } - for (i = self->current_row = 0; i < self->max_row; ++i) { - PyObject *row_dict = query_next_dict(self, noargs); + for (i = self->current_row = 0; i < self->max_row; ++i) { + PyObject *row_dict = query_next_dict(self, noargs); - if (!row_dict) { - Py_DECREF(result_list); return NULL; + if (!row_dict) { + Py_DECREF(result_list); return NULL; + } + PyList_SET_ITEM(result_list, i, row_dict); } - PyList_SET_ITEM(result_list, i, row_dict); } return result_list; @@ -452,10 +565,15 @@ static char query_dictiter__doc__[] = static PyObject * query_dictiter(queryObject *self, PyObject *noargs) { + PyObject *res; + if (!dictiter) { return query_dictresult(self, noargs); } + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + return PyObject_CallFunction(dictiter, "(O)", self); } @@ -469,10 +587,15 @@ static char query_onenamed__doc__[] = static PyObject * query_onenamed(queryObject *self, PyObject *noargs) { + PyObject *res; + if (!namednext) { return query_one(self, noargs); } + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + if (self->current_row >= self->max_row) { Py_INCREF(Py_None); return Py_None; } @@ -491,10 +614,15 @@ static char query_singlenamed__doc__[] = static PyObject * query_singlenamed(queryObject *self, PyObject *noargs) { + PyObject *res; + if (!namednext) { return query_single(self, noargs); } + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + if (self->max_row != 1) { if (self->max_row) set_error_msg(MultipleResultsError, "Multiple results found"); @@ -522,11 +650,15 @@ query_namedresult(queryObject *self, PyObject *noargs) return query_getresult(self, noargs); } - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (PyList_Check(res)) return res; - res_list = PySequence_List(res); - Py_DECREF(res); + if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { + + res = PyObject_CallFunction(namediter, "(O)", self); + if (!res) return NULL; + if (PyList_Check(res)) return res; + res_list = PySequence_List(res); + Py_DECREF(res); + } + return res_list; } @@ -545,11 +677,15 @@ query_namediter(queryObject *self, PyObject *noargs) return query_iter(self); } - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (!PyList_Check(res)) return res; - res_iter = (Py_TYPE(res)->tp_iter)((PyObject *) self); - Py_DECREF(res); + if ((res_iter = _get_async_result(self, 1)) == (PyObject *)self) { + + res = PyObject_CallFunction(namediter, "(O)", self); + if (!res) return NULL; + if (!PyList_Check(res)) return res; + res_iter = (Py_TYPE(res)->tp_iter)((PyObject *) self); + Py_DECREF(res); + } + return res_iter; } @@ -564,25 +700,28 @@ query_scalarresult(queryObject *self, PyObject *noargs) { PyObject *result_list; - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (!(result_list = PyList_New(self->max_row))) { + return 
NULL; + } - for (self->current_row = 0; - self->current_row < self->max_row; - ++self->current_row) - { - PyObject *value = _query_value_in_column(self, 0); + for (self->current_row = 0; + self->current_row < self->max_row; + ++self->current_row) + { + PyObject *value = _query_value_in_column(self, 0); - if (!value) { - Py_DECREF(result_list); return NULL; + if (!value) { + Py_DECREF(result_list); return NULL; + } + PyList_SET_ITEM(result_list, self->current_row, value); } - PyList_SET_ITEM(result_list, self->current_row, value); } return result_list; @@ -597,10 +736,15 @@ static char query_scalariter__doc__[] = static PyObject * query_scalariter(queryObject *self, PyObject *noargs) { + PyObject *res; + if (!scalariter) { return query_scalarresult(self, noargs); } + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; @@ -621,17 +765,21 @@ query_onescalar(queryObject *self, PyObject *noargs) { PyObject *value; - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } + if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); return Py_None; + } + + value = _query_value_in_column(self, 0); + if (value) ++self->current_row; } - value = _query_value_in_column(self, 0); - if (value) ++self->current_row; return value; } @@ -648,22 +796,26 @@ query_singlescalar(queryObject *self, PyObject *noargs) { PyObject *value; - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } + if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + value = _query_value_in_column(self, 0); + if (value) ++self->current_row; } - self->current_row = 0; - value = _query_value_in_column(self, 0); - if (value) ++self->current_row; return value; } @@ -752,7 +904,7 @@ static PyTypeObject queryType = { 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT |Py_TPFLAGS_HAVE_ITER, /* tp_flags */ - query__doc__, /* tp_doc */ + query__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 9ce70684..c1f265ef 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -72,6 +72,12 @@ def connect(): return connection +def connect_nowait(): + """Start a basic pg connection in a non-blocking manner.""" + # noinspection PyArgumentList + return pg.connect(dbname, dbhost, dbport, nowait=True) + + class TestCanConnect(unittest.TestCase): """Test whether a basic connection to PostgreSQL is possible.""" @@ -85,6 +91,20 @@ def testCanConnect(self): except pg.Error: self.fail('Cannot close the database connection') + def 
testCanConnectNoWait(self): + try: + connection = connect() + rc = connection.poll() + while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): + rc = connection.poll() + except pg.Error as error: + self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.assertEqual(rc, pg.POLLING_OK) + try: + connection.close() + except pg.Error: + self.fail('Cannot close the database connection') + class TestConnectObject(unittest.TestCase): """Test existence of basic pg connection methods.""" @@ -132,9 +152,10 @@ def testAllConnectMethods(self): cancel close date_format describe_prepared endcopy escape_bytea escape_identifier escape_literal escape_string fileno get_cast_hook get_notice_receiver getline getlo getnotify - inserttable locreate loimport parameter - prepare putline query query_prepared reset - set_cast_hook set_notice_receiver source transaction + inserttable is_non_blocking locreate loimport parameter poll + prepare putline query query_prepared reset send_query + set_cast_hook set_non_blocking set_notice_receiver + source transaction '''.split() connection_methods = [ a for a in dir(self.connection) @@ -220,6 +241,49 @@ def testMethodQuery(self): def testMethodQueryEmpty(self): self.assertRaises(ValueError, self.connection.query, '') + def testMethodSendQuerySingle(self): + query = self.connection.send_query + for q, args, result in ( + ("select 1+1 as a", (), 2), + ("select 1+$1 as a", ((1,),), 2), + ("select 1+$1+$2 as a", ((2, 3),), 6)): + pgq = query(q, *args) + self.assertEqual(self.connection.transaction(), pg.TRANS_ACTIVE) + self.assertEqual(pgq.getresult()[0][0], result) + self.assertEqual(self.connection.transaction(), pg.TRANS_ACTIVE) + self.assertIsNone(pgq.getresult()) + self.assertEqual(self.connection.transaction(), pg.TRANS_IDLE) + + pgq = query(q, *args) + self.assertEqual(pgq.namedresult()[0].a, result) + self.assertIsNone(pgq.namedresult()) + + pgq = query(q, *args) + self.assertEqual(pgq.dictresult()[0]['a'], result) + self.assertIsNone(pgq.dictresult()) + + def testMethodSendQueryMultiple(self): + query = self.connection.send_query + + pgq = query("select 1+1; select 'pg';") + self.assertEqual(pgq.getresult()[0][0], 2) + self.assertEqual(pgq.getresult()[0][0], 'pg') + self.assertIsNone(pgq.getresult()) + + pgq = query("select 1+1 as a; select 'pg' as a;") + self.assertEqual(pgq.namedresult()[0].a, 2) + self.assertEqual(pgq.namedresult()[0].a, 'pg') + self.assertIsNone(pgq.namedresult()) + + pgq = query("select 1+1 as a; select 'pg' as a;") + self.assertEqual(pgq.dictresult()[0]['a'], 2) + self.assertEqual(pgq.dictresult()[0]['a'], 'pg') + self.assertIsNone(pgq.dictresult()) + + def testMethodSendQueryEmpty(self): + query = self.connection.send_query('') + self.assertRaises(ValueError, query.getresult) + def testAllQueryMembers(self): query = self.connection.query("select true where false") members = ''' @@ -420,6 +484,18 @@ def testGetresultString(self): self.assertIsInstance(v, str) self.assertEqual(v, result) + def testGetresultAsync(self): + q = "select 0" + result = [(0,)] + query = self.c.send_query(q) + r = query.getresult() + self.assertIsInstance(r, list) + v = r[0] + self.assertIsInstance(v, tuple) + self.assertIsInstance(v[0], int) + self.assertEqual(r, result) + self.assertIsNone(query.getresult()) + def testDictresult(self): q = "select 0 as alias0" result = [{'alias0': 0}] @@ -452,6 +528,18 @@ def testDictresultString(self): self.assertIsInstance(v, str) self.assertEqual(v, result) + def testDictresultAsync(self): + q = "select 0 as alias0" + 
result = [{'alias0': 0}] + query = self.c.send_query(q) + r = query.dictresult() + self.assertIsInstance(r, list) + v = r[0] + self.assertIsInstance(v, dict) + self.assertIsInstance(v['alias0'], int) + self.assertEqual(r, result) + self.assertIsNone(query.dictresult()) + def testNamedresult(self): q = "select 0 as alias0" result = [(0,)] @@ -482,6 +570,17 @@ def testNamedresultWithBadFieldnames(self): self.assertEqual(v._fields[:6], fields) self.assertEqual(v._fields[6], 'and_a_good_one') + def testNamedresultAsync(self): + q = "select 0 as alias0" + query = self.c.send_query(q) + result = [(0,)] + r = query.namedresult() + self.assertEqual(r, result) + v = r[0] + self.assertEqual(v._fields, ('alias0',)) + self.assertEqual(v.alias0, 0) + self.assertIsNone(query.namedresult()) + def testGet3Cols(self): q = "select 1,2,3" result = [(1, 2, 3)] diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 49edf374..8704fbe5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -202,16 +202,16 @@ def testAllDBAttributes(self): 'get_parameter', 'get_relations', 'get_tables', 'getline', 'getlo', 'getnotify', 'has_table_privilege', 'host', - 'insert', 'inserttable', + 'insert', 'inserttable', 'is_non_blocking', 'locreate', 'loimport', 'notification_handler', 'options', - 'parameter', 'pkey', 'port', + 'parameter', 'pkey', 'poll', 'port', 'prepare', 'protocol_version', 'putline', 'query', 'query_formatted', 'query_prepared', 'release', 'reopen', 'reset', 'rollback', - 'savepoint', 'server_version', - 'set_cast_hook', 'set_notice_receiver', + 'savepoint', 'send_query', 'server_version', + 'set_cast_hook', 'set_non_blocking', 'set_notice_receiver', 'set_parameter', 'socket', 'source', 'ssl_attributes', 'ssl_in_use', 'start', 'status', From 17f28cdbe7f53459919b1ca8143b14f06d1db7b1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Jun 2020 23:38:53 +0200 Subject: [PATCH 028/194] Add project URLs to setup --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 1bd3efa7..a95b0a2b 100755 --- a/setup.py +++ b/setup.py @@ -212,6 +212,11 @@ def finalize_options(self): author_email="darcy@PyGreSQL.org", url="http://www.pygresql.org", download_url="http://www.pygresql.org/download/", + project_urls={ + "Documentation": "https://pygresql.org/contents/", + "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", + "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", + "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, platforms=["any"], license="PostgreSQL", py_modules=py_modules, From f82d4cd20dfa66ec791091581e60e299b329c70e Mon Sep 17 00:00:00 2001 From: Justin Pryzby Date: Fri, 19 Jun 2020 19:51:58 -0500 Subject: [PATCH 029/194] Add optional columns list to inserttable (#24) --- docs/contents/changelog.rst | 10 +++-- docs/contents/pg/connection.rst | 5 ++- pgconn.c | 74 +++++++++++++++++++++++++++----- tests/test_classic_connection.py | 36 ++++++++++++++++ 4 files changed, 110 insertions(+), 15 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 856bd995..30c0b287 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -14,16 +14,18 @@ Version 5.2 (to be released) Additional connection parameter ``nowait``, and connection methods `send_query()`, `poll()`, `set_non_blocking()`, `is_non_blocking()`. Generously contributed by Patrick TJ McPhee (#19). 
+ - The `types` parameter of `format_query` can now be passed as a string + that will be split on whitespace when values are passed as a sequence, + and the types can now also be specified using actual Python types + instead of type names. Suggested by Justin Pryzby (#38). + - The `inserttable()` method now accepts an optional column list that will + be passed on to the COPY command. Contributed by Justin Pryzby (#24). - Changes to the DB-API 2 module (pgdb): - When using Python 2, errors are now derived from StandardError instead of Exception, as required by the DB-API 2 compliance test. - Connection arguments containing single quotes caused problems (reported and fixed by Tyler Ramer and Jamie McAtamney). - - The `types` parameter of `format_query` can now be passed as a string - that will be split on whitespace when values are passed as a sequence, - and the types can now also be specified using actual Python types - instead of type names. Suggested by Justin Pryzby (#38). Version 5.1.2 (2020-04-19) -------------------------- diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index f342c028..c07b6b1b 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -487,12 +487,13 @@ first, otherwise :meth:`Connection.getnotify` will always return ``None``. inserttable -- insert a list into a table ----------------------------------------- -.. method:: Connection.inserttable(table, values) +.. method:: Connection.inserttable(table, values, [columns]) Insert a Python list into a database table :param str table: the table name :param list values: list of rows values + :param list columns: list of column names :rtype: None :raises TypeError: invalid connection, bad argument type, or too many arguments :raises MemoryError: insert buffer could not be allocated @@ -503,6 +504,8 @@ It inserts the whole values list into the given table. Internally, it uses the COPY command of the PostgreSQL database. The list is a list of tuples/lists that define the values for each inserted row. The rows values may contain string, integer, long or double (real) values. +``columns`` is a optional sequence of column names to be passed on +to the COPY command. .. 
warning:: diff --git a/pgconn.c b/pgconn.c index f8c81e28..34d3cfb8 100644 --- a/pgconn.c +++ b/pgconn.c @@ -686,8 +686,9 @@ conn_is_non_blocking(connObject *self, PyObject *args) /* Insert table */ static char conn_inserttable__doc__[] = -"inserttable(table, data) -- insert list into table\n\n" -"The fields in the list must be in the same order as in the table.\n"; +"inserttable(table, data, [columns]) -- insert list into table\n\n" +"The fields in the list must be in the same order as in the table\n" +"or in the list of columns if one is specified.\n"; static PyObject * conn_inserttable(connObject *self, PyObject *args) @@ -696,10 +697,11 @@ conn_inserttable(connObject *self, PyObject *args) char *table, *buffer, *bufpt; int encoding; size_t bufsiz; - PyObject *list, *sublist, *item; + PyObject *list, *sublist, *item, *columns = NULL; PyObject *(*getitem) (PyObject *, Py_ssize_t); PyObject *(*getsubitem) (PyObject *, Py_ssize_t); - Py_ssize_t i, j, m, n; + PyObject *(*getcolumn) (PyObject *, Py_ssize_t); + Py_ssize_t i, j, m, n = 0; if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); @@ -707,7 +709,7 @@ conn_inserttable(connObject *self, PyObject *args) } /* gets arguments */ - if (!PyArg_ParseTuple(args, "sO:filter", &table, &list)) { + if (!PyArg_ParseTuple(args, "sO|O", &table, &list, &columns)) { PyErr_SetString( PyExc_TypeError, "Method inserttable() expects a string and a list as arguments"); @@ -731,12 +733,68 @@ conn_inserttable(connObject *self, PyObject *args) return NULL; } + /* checks columns type */ + if (columns) { + if (PyList_Check(columns)) { + n = PyList_Size(columns); + getcolumn = PyList_GetItem; + } + else if (PyTuple_Check(columns)) { + n = PyTuple_Size(columns); + getcolumn = PyTuple_GetItem; + } + else { + PyErr_SetString( + PyExc_TypeError, + "Method inserttable() expects a list or a tuple" + " as third argument"); + return NULL; + } + if (!n) { + /* no columns specified, nothing to do */ + Py_INCREF(Py_None); + return Py_None; + } + } + /* allocate buffer */ if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) return PyErr_NoMemory(); + encoding = PQclientEncoding(self->cnx); + /* starts query */ - sprintf(buffer, "copy %s from stdin", table); + bufpt = buffer; + table = PQescapeIdentifier(self->cnx, table, strlen(table)); + bufpt += sprintf(bufpt, "copy %s", table); + PQfreemem(table); + if (columns) { + /* adds a string like f" ({','.join(columns)})" */ + bufpt += sprintf(bufpt, " ("); + for (int i = 0; i < n; ++i) { + PyObject *obj = getcolumn(columns, i); + ssize_t slen; + char *col; + + if (PyBytes_Check(obj)) { + PyBytes_AsStringAndSize(obj, &col, &slen); + } + else if (PyUnicode_Check(obj)) { + obj = get_encoded_string(obj, encoding); + if (!obj) return NULL; /* pass the UnicodeEncodeError */ + PyBytes_AsStringAndSize(obj, &col, &slen); + Py_DECREF(obj); + } else { + PyErr_SetString( + PyExc_TypeError, + "The third argument must contain only strings"); + } + col = PQescapeIdentifier(self->cnx, col, (size_t) slen); + bufpt += sprintf(bufpt, "%s%s", col, i == n - 1 ? 
")" : ","); + PQfreemem(col); + } + } + sprintf(bufpt, " from stdin"); Py_BEGIN_ALLOW_THREADS result = PQexec(self->cnx, buffer); @@ -748,12 +806,8 @@ conn_inserttable(connObject *self, PyObject *args) return NULL; } - encoding = PQclientEncoding(self->cnx); - PQclear(result); - n = 0; /* not strictly necessary but avoids warning */ - /* feed table */ for (i = 0; i < m; ++i) { sublist = getitem(list, i); diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c1f265ef..40d16432 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1889,6 +1889,42 @@ def testInserttableNullValues(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) + def testInserttableNoColumn(self): + data = [()] * 10 + self.c.inserttable('test', data, []) + self.assertEqual(self.get_back(), []) + + def testInserttableOnlyOneColumn(self): + data = [(42,)] * 50 + self.c.inserttable('test', data, ['i4']) + data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 + self.assertEqual(self.get_back(), data) + + def testInserttableOnlyTwoColumns(self): + data = [(bool(i % 2), i * .5) for i in range(20)] + self.c.inserttable('test', data, ('b', 'f4')) + # noinspection PyTypeChecker + data = [(None,) * 3 + (bool(i % 2),) + (None,) * 3 + (i * .5,) + + (None,) * 6 for i in range(20)] + self.assertEqual(self.get_back(), data) + + def testInserttableWithInvalidTableName(self): + data = [(42,)] + # check that the table name is not inserted unescaped + # (this would pass otherwise since there is a column named i4) + self.assertRaises(Exception, self.c.inserttable, 'test (i4)', data) + # make sure that it works if parameters are passed properly + self.c.inserttable('test', data, ['i4']) + + def testInserttableWithInvalidColumnName(self): + data = [(2, 4)] + # check that the column names are not inserted unescaped + # (this would pass otherwise since there are columns i2 and i4) + self.assertRaises( + Exception, self.c.inserttable, 'test', data, ['i2,i4']) + # make sure that it works if parameters are passed properly + self.c.inserttable('test', data, ['i2', 'i4']) + def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), True, '2999-12-31', '11:59:59', 1e99, From 584ca74cb4c8a8546dbe00b5d20f9fee448f5326 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 20 Jun 2020 17:50:50 +0200 Subject: [PATCH 030/194] Add typlen attribute to the DBTypes class --- docs/contents/changelog.rst | 2 ++ docs/contents/pg/connection.rst | 2 +- docs/contents/pg/db_types.rst | 7 ++++--- pg.py | 29 ++++++++++++++++------------- tests/test_classic_dbwrapper.py | 4 ++++ 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 30c0b287..c617eca8 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -20,6 +20,8 @@ Version 5.2 (to be released) instead of type names. Suggested by Justin Pryzby (#38). - The `inserttable()` method now accepts an optional column list that will be passed on to the COPY command. Contributed by Justin Pryzby (#24). + - The `DBTyptes` class now also includes the `typlen` attribute with + information about the size of the type (contributed by Justin Pryzby). 
- Changes to the DB-API 2 module (pgdb): - When using Python 2, errors are now derived from StandardError diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index c07b6b1b..842822ad 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -504,7 +504,7 @@ It inserts the whole values list into the given table. Internally, it uses the COPY command of the PostgreSQL database. The list is a list of tuples/lists that define the values for each inserted row. The rows values may contain string, integer, long or double (real) values. -``columns`` is a optional sequence of column names to be passed on +``columns`` is an optional sequence of column names to be passed on to the COPY command. .. warning:: diff --git a/docs/contents/pg/db_types.rst b/docs/contents/pg/db_types.rst index 2119ecd3..d7333a41 100644 --- a/docs/contents/pg/db_types.rst +++ b/docs/contents/pg/db_types.rst @@ -13,14 +13,15 @@ returned by :meth:`DB.get_attnames` as dictionary values). These type names are strings which are equal to either the simple PyGreSQL names or to the more fine-grained registered PostgreSQL type names if these -have been enabled with :meth:`DB.use_regtypes`. Besides being strings, they -carry additional information about the associated PostgreSQL type in the -following attributes: +have been enabled with :meth:`DB.use_regtypes`. Type names are strings that +are augmented with additional information about the associated PostgreSQL +type that can be inspected using the following attributes: - *oid* -- the PostgreSQL type OID - *pgtype* -- the internal PostgreSQL data type name - *regtype* -- the registered PostgreSQL data type name - *simple* -- the more coarse-grained PyGreSQL type name + - *typlen* -- internal size of the type, negative if variable - *typtype* -- `b` = base type, `c` = composite type etc. - *category* -- `A` = Array, `b` =Boolean, `C` = Composite etc. - *delim* -- delimiter for array types diff --git a/pg.py b/pg.py index cf996cf0..9e02d7a6 100644 --- a/pg.py +++ b/pg.py @@ -361,7 +361,7 @@ def get_type_dict(self): def _quote_if_unqualified(param, name): """Quote parameter representing a qualified name. - Puts a quote_ident() call around the give parameter unless + Puts a quote_ident() call around the given parameter unless the name contains a dot, in which case the name is ambiguous (could be a qualified name or just a name with a dot in it) and must be quoted manually by the caller. @@ -1212,6 +1212,7 @@ class DbType(str): pgtype: the internal PostgreSQL data type name regtype: the registered PostgreSQL data type name simple: the more coarse-grained PyGreSQL type name + typlen: the internal size, negative if variable typtype: b = base type, c = composite type etc. category: A = Array, b = Boolean, C = Composite etc. 
delim: delimiter for array types @@ -1245,21 +1246,21 @@ def __init__(self, db): self._typecasts.get_attnames = self.get_attnames self._typecasts.connection = self._db if db.server_version < 80400: - # older remote databases (not officially supported) + # very old remote databases (not officially supported) self._query_pg_type = ( "SELECT oid, typname, typname::text::regtype," - " typtype, null as typcategory, typdelim, typrelid" + " typlen, typtype, null as typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type" " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") else: self._query_pg_type = ( "SELECT oid, typname, typname::regtype," - " typtype, typcategory, typdelim, typrelid" + " typlen, typtype, typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type" " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") def add(self, oid, pgtype, regtype, - typtype, category, delim, relid): + typlen, typtype, category, delim, relid): """Create a PostgreSQL type name with additional info.""" if oid in self: return self[oid] @@ -1269,6 +1270,7 @@ def add(self, oid, pgtype, regtype, typ.simple = simple typ.pgtype = pgtype typ.regtype = regtype + typ.typlen = typlen typ.typtype = typtype typ.category = category typ.delim = delim @@ -1284,7 +1286,7 @@ def __missing__(self, key): except ProgrammingError: res = None if not res: - raise KeyError('Type %s could not be found' % key) + raise KeyError('Type %s could not be found' % (key,)) res = res[0] typ = self.add(*res) self[typ.oid] = self[typ.pgtype] = typ @@ -1610,24 +1612,25 @@ def __init__(self, *args, **kw): self.adapter = Adapter(self) self.dbtypes = DbTypes(self) if db.server_version < 80400: - # support older remote data bases (not officially supported) + # very old remote databases (not officially supported) self._query_attnames = ( "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype," - " t.typtype, null as typcategory, t.typdelim, t.typrelid" + " t.typlen, t.typtype, null as typcategory," + " t.typdelim, t.typrelid" " FROM pg_catalog.pg_attribute a" " JOIN pg_catalog.pg_type t" " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s" - " AND NOT a.attisdropped ORDER BY a.attnum") + " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass" + " AND %s AND NOT a.attisdropped ORDER BY a.attnum") else: self._query_attnames = ( "SELECT a.attname, t.oid, t.typname, t.typname::regtype," - " t.typtype, t.typcategory, t.typdelim, t.typrelid" + " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" " FROM pg_catalog.pg_attribute a" " JOIN pg_catalog.pg_type t" " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s" - " AND NOT a.attisdropped ORDER BY a.attnum") + " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass" + " AND %s AND NOT a.attisdropped ORDER BY a.attnum") db.set_cast_hook(self.dbtypes.typecast) # For debugging scripts, self.debug can be set # * to a string format specification (e.g. in CGI set to "%s
"), diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8704fbe5..ee9a64fe 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4062,6 +4062,7 @@ def testDbTypesInfo(self): self.assertEqual(typ.pgtype, 'numeric') self.assertEqual(typ.regtype, 'numeric') self.assertEqual(typ.simple, 'num') + self.assertEqual(typ.typlen, -1) self.assertEqual(typ.typtype, 'b') self.assertEqual(typ.category, 'N') self.assertEqual(typ.delim, ',') @@ -4075,6 +4076,7 @@ def testDbTypesInfo(self): self.assertEqual(typ.pgtype, 'pg_type') self.assertEqual(typ.regtype, 'pg_type') self.assertEqual(typ.simple, 'record') + self.assertEqual(typ.typlen, -1) self.assertEqual(typ.typtype, 'c') self.assertEqual(typ.category, 'C') self.assertEqual(typ.delim, ',') @@ -4086,11 +4088,13 @@ def testDbTypesInfo(self): self.assertIn('typname', attnames) typname = attnames['typname'] self.assertEqual(typname, 'name' if self.regtypes else 'text') + self.assertEqual(typname.typlen, 64) # base self.assertEqual(typname.typtype, 'b') # base self.assertEqual(typname.category, 'S') # string self.assertIn('typlen', attnames) typlen = attnames['typlen'] self.assertEqual(typlen, 'smallint' if self.regtypes else 'int') + self.assertEqual(typlen.typlen, 2) # base self.assertEqual(typlen.typtype, 'b') # base self.assertEqual(typlen.category, 'N') # numeric From d06b02bb15662e43acae179f77bb1701cda29fb6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 20 Jun 2020 21:11:36 +0200 Subject: [PATCH 031/194] Add new query method fieldinfo(field) (#29) --- docs/conf.py | 4 +- docs/contents/changelog.rst | 4 +- docs/contents/pg/query.rst | 32 ++++++++++++-- pg.py | 2 +- pgquery.c | 73 ++++++++++++++++++++++++++++++++ tests/test_classic_connection.py | 30 ++++++++++++- 6 files changed, 137 insertions(+), 8 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 3eae7b58..c20b8ab2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -112,8 +112,8 @@ 'Notice', 'DATETIME'), 'data': ('defbase', 'defhost', 'defopt', 'defpasswd', 'defport', 'defuser'), - 'exc': ('Exception', 'IOError', 'KeyError', 'MemoryError', - 'SyntaxError', 'TypeError', 'ValueError', + 'exc': ('Exception', 'IndexError', 'IOError', 'KeyError', + 'MemoryError', 'SyntaxError', 'TypeError', 'ValueError', 'pg.InternalError', 'pg.InvalidResultError', 'pg.MultipleResultsError', 'pg.NoResultError', 'pg.OperationalError', 'pg.ProgrammingError'), diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index c617eca8..edd822e4 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -10,6 +10,8 @@ Version 5.2 (to be released) of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). - New query method `memsize()` that gets the memory size allocated by the query (needs PostgreSQL >= 12 on the client). + - New query method `fieldinfo()` that gets name and type information for + one or all field(s) of the query. Contributed by Justin Pryzby (#39). - Experimental support for asynchronous command processing. Additional connection parameter ``nowait``, and connection methods `send_query()`, `poll()`, `set_non_blocking()`, `is_non_blocking()`. @@ -20,7 +22,7 @@ Version 5.2 (to be released) instead of type names. Suggested by Justin Pryzby (#38). - The `inserttable()` method now accepts an optional column list that will be passed on to the COPY command. Contributed by Justin Pryzby (#24). 
- - The `DBTyptes` class now also includes the `typlen` attribute with + - The `DBTypes` class now also includes the `typlen` attribute with information about the size of the type (contributed by Justin Pryzby). - Changes to the DB-API 2 module (pgdb): diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 4fcf46fb..756e69ed 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -328,12 +328,12 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. .. versionadded:: 5.1 -listfields -- list fields names of previous query result --------------------------------------------------------- +listfields -- list field names of previous query result +------------------------------------------------------- .. method:: Query.listfields() - List fields names of previous query result + List field names of previous query result :returns: field names :rtype: list @@ -374,6 +374,32 @@ build a function that converts result list strings to their correct type, using a hardcoded table definition. The number returned is the field rank in the query result. +fieldinfo -- detailed info about fields of previous query result +---------------------------------------------------------------- + +.. method:: Query.fieldinfo([field]) + + Get information on one or all fields of the last query + + :param field: a column number or name (optional) + :type field: int or str + :returns: field info tuple(s) for all fields or given field + :rtype: tuple + :raises IndexError: field does not exist + :raises TypeError: too many parameters + +If the ``field`` is specified by passing either a column number or a field +name, a four-tuple with information for the specified field of the previous +query result will be returned. If no ``field`` is specified, a tuple of +four-tuples for every field of the previous query result will be returned, +in the order as they appear in the query result. + +The four-tuples contain the following information: The field name, the +internal OID number of the field type, the size in bytes of the column or a +negative value if it is of variable size, and a type-specific modifier value. + +.. versionadded:: 5.2 + ntuples -- return number of tuples in query object -------------------------------------------------- diff --git a/pg.py b/pg.py index 9e02d7a6..bc447c71 100644 --- a/pg.py +++ b/pg.py @@ -191,7 +191,7 @@ def wrapper(arg): return decorator -# Auxiliary classes and functions that are independent from a DB connection: +# Auxiliary classes and functions that are independent of a DB connection: try: # noinspection PyUnresolvedReferences from inspect import signature diff --git a/pgquery.c b/pgquery.c index fefd8e70..335d6813 100644 --- a/pgquery.c +++ b/pgquery.c @@ -340,6 +340,77 @@ query_fieldnum(queryObject *self, PyObject *args) return PyInt_FromLong(num); } +/* Build a tuple with info for query field with given number. */ +static PyObject * +_query_build_field_info(PGresult *res, int col_num) { + PyObject *info; + + info = PyTuple_New(4); + if (info) { + PyTuple_SET_ITEM(info, 0, PyStr_FromString(PQfname(res, col_num))); + PyTuple_SET_ITEM(info, 1, PyInt_FromLong(PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 2, PyInt_FromLong(PQfsize(res, col_num))); + PyTuple_SET_ITEM(info, 3, PyInt_FromLong(PQfmod(res, col_num))); + } + return info; +} + +/* Get information on one or all fields in last result. 
*/ +static char query_fieldinfo__doc__[] = +"fieldinfo() -- return info on field(s) in query"; + +static PyObject * +query_fieldinfo(queryObject *self, PyObject *args) +{ + PyObject *result, *field = NULL; + int num; + + /* gets args */ + if (!PyArg_ParseTuple(args, "|O", &field)) { + PyErr_SetString(PyExc_TypeError, + "Method fieldinfo() takes one optional argument only"); + return NULL; + } + + /* check optional field arg */ + if (field) { + /* gets field number */ + if (PyBytes_Check(field)) { + num = PQfnumber(self->result, PyBytes_AsString(field)); + } else if (PyStr_Check(field)) { + PyObject *tmp = get_encoded_string(field, self->encoding); + if (!tmp) return NULL; + num = PQfnumber(self->result, PyBytes_AsString(tmp)); + Py_DECREF(tmp); + } else if (PyInt_Check(field)) { + num = (int) PyInt_AsLong(field); + } else { + PyErr_SetString(PyExc_TypeError, + "Field should be given as column number or name"); + return NULL; + } + if (num < 0 || num >= self->num_fields) { + PyErr_SetString(PyExc_IndexError, "Unknown field"); + return NULL; + } + return _query_build_field_info(self->result, num); + } + + if (!(result = PyTuple_New(self->num_fields))) { + return NULL; + } + for (num = 0; num < self->num_fields; ++num) { + PyObject *info = _query_build_field_info(self->result, num); + if (!info) { + Py_DECREF(result); + return NULL; + } + PyTuple_SET_ITEM(result, num, info); + } + return result; +} + + /* Retrieve one row from the result as a tuple. */ static char query_one__doc__[] = "one() -- Get one row from the result of a query\n\n" @@ -869,6 +940,8 @@ static struct PyMethodDef query_methods[] = { METH_VARARGS, query_fieldnum__doc__}, {"listfields", (PyCFunction) query_listfields, METH_NOARGS, query_listfields__doc__}, + {"fieldinfo", (PyCFunction) query_fieldinfo, + METH_VARARGS, query_fieldinfo__doc__}, {"ntuples", (PyCFunction) query_ntuples, METH_NOARGS, query_ntuples__doc__}, #ifdef MEMORY_SIZE diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 40d16432..c7669d61 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -287,7 +287,7 @@ def testMethodSendQueryEmpty(self): def testAllQueryMembers(self): query = self.connection.query("select true where false") members = ''' - dictiter dictresult fieldname fieldnum getresult + dictiter dictresult fieldinfo fieldname fieldnum getresult listfields memsize namediter namedresult ntuples one onedict onenamed onescalar scalariter scalarresult single singledict singlenamed singlescalar @@ -694,6 +694,34 @@ def testFieldnum(self): self.assertIsInstance(r, int) self.assertEqual(r, 3) + def testFieldInfoName(self): + q = ('select true as FooBar, 42::smallint as "FooBar",' + ' 4.2::numeric(4,2) as foo_bar, \'baz\'::char(3) as "Foo Bar"') + f = self.c.query(q).fieldinfo + result = (('foobar', 16, 1, -1), ('FooBar', 21, 2, -1), + ('foo_bar', 1700, -1, ((4 << 16) | 2) + 4), + ('Foo Bar', 1042, -1, 3 + 4)) + r = f() + self.assertIsInstance(r, tuple) + self.assertEqual(len(r), 4) + self.assertEqual(r, result) + for field_num, info in enumerate(result): + field_name = info[0] + if field_num > 0: + field_name = '"%s"' % field_name + r = f(field_name) + self.assertIsInstance(r, tuple) + self.assertEqual(len(r), 4) + self.assertEqual(r, info) + r = f(field_num) + self.assertIsInstance(r, tuple) + self.assertEqual(len(r), 4) + self.assertEqual(r, info) + self.assertRaises(IndexError, f, 'foobaz') + self.assertRaises(IndexError, f, '"Foobar"') + self.assertRaises(IndexError, f, -1) + 
self.assertRaises(IndexError, f, 4) + def testNtuples(self): # deprecated q = "select 1 where false" r = self.c.query(q).ntuples() From 28c642b90c17ca88b2821bb0c0913ae5c3bb7950 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 20 Jun 2020 23:40:43 +0200 Subject: [PATCH 032/194] Do not close large objects when deallocating (#30) If a large object goes out of scope and is deallocated after a transaction is closed, it will have been already closed on the server side and you will get an error message on the server when the dealloc method tries to close it again. In other cases, the object might be closed too early when the Python object falls out of scope. Therefore it is best to not close it here. --- pglarge.c | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pglarge.c b/pglarge.c index 4f9e0f3e..590f3fbb 100644 --- a/pglarge.c +++ b/pglarge.c @@ -12,8 +12,12 @@ static void large_dealloc(largeObject *self) { - if (self->lo_fd >= 0 && self->pgcnx->valid) - lo_close(self->pgcnx->cnx, self->lo_fd); + /* Note: We do not try to close the large object here anymore, + since the server automatically closes it at the end of the + transaction in which it was created. So the object might already + be closed, which will then cause error messages on the server. + In other situations we might close the object too early here + if the Python object falls out of scope but is still needed. */ Py_XDECREF(self->pgcnx); PyObject_Del(self); From 77098e11f0c0041b53dcc6512d2e5a4a6805905b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 00:09:19 +0200 Subject: [PATCH 033/194] Minor changes in the large objects documentation (#30) --- docs/contents/changelog.rst | 3 ++ docs/contents/pg/large_objects.rst | 71 ++++++++++++++++-------------- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index edd822e4..8139b27e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -24,6 +24,9 @@ Version 5.2 (to be released) be passed on to the COPY command. Contributed by Justin Pryzby (#24). - The `DBTypes` class now also includes the `typlen` attribute with information about the size of the type (contributed by Justin Pryzby). + - Large objects on the server are not closed any more when they are + deallocated as Python objects, since this could cause several problems. + Bug report and analysis by Justin Pryzby (#30). - Changes to the DB-API 2 module (pgdb): - When using Python 2, errors are now derived from StandardError diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index 3efa5d3b..0d48fd84 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -5,22 +5,23 @@ LargeObject -- Large Objects .. class:: LargeObject -Objects that are instances of the class :class:`LargeObject` are used to handle -all the requests concerning a PostgreSQL large object. These objects embed -and hide all the "recurrent" variables (object OID and connection), exactly -in the same way :class:`Connection` instances do, thus only keeping significant -parameters in function calls. The :class:`LargeObject` instance keeps a -reference to the :class:`Connection` object used for its creation, sending -requests though with its parameters. Any modification but dereferencing the +Instances of the class :class:`LargeObject` are used to handle all the +requests concerning a PostgreSQL large object. 
These objects embed and hide +all the recurring variables (object OID and connection), in the same way +:class:`Connection` instances do, thus only keeping significant parameters +in function calls. The :class:`LargeObject` instance keeps a reference to +the :class:`Connection` object used for its creation, sending requests +though with its parameters. Any modification other than dereferencing the :class:`Connection` object will thus affect the :class:`LargeObject` instance. Dereferencing the initial :class:`Connection` object is not a problem since Python won't deallocate it before the :class:`LargeObject` instance -dereferences it. All functions return a generic error message on call error, -whatever the exact error was. The :attr:`error` attribute of the object allows -to get the exact error message. +dereferences it. All functions return a generic error message on error. +The exact error message is provided by the object's :attr:`error` attribute. -See also the PostgreSQL programmer's guide for more information about the -large object interface. +See also the PostgreSQL documentation for more information about the +`large object interface`__. + +__ https://www.postgresql.org/docs/current/largeobjects.html open -- open a large object --------------------------- @@ -34,9 +35,10 @@ open -- open a large object :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: already opened object, or open error -This method opens a large object for reading/writing, in the same way than the -Unix open() function. The mode value can be obtained by OR-ing the constants -defined in the :mod:`pg` module (:const:`INV_READ`, :const:`INV_WRITE`). +This method opens a large object for reading/writing, in a similar manner as +the Unix open() function does for files. The mode value can be obtained by +OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, +:const:`INV_WRITE`). close -- close a large object ----------------------------- @@ -50,7 +52,7 @@ close -- close a large object :raises TypeError: too many parameters :raises IOError: object is not opened, or close error -This method closes a previously opened large object, in the same way than +This method closes a previously opened large object, in a similar manner as the Unix close() function. read, write, tell, seek, unlink -- file-like large object handling @@ -60,7 +62,7 @@ read, write, tell, seek, unlink -- file-like large object handling Read data from large object - :param int size: maximal size of the buffer to be read + :param int size: maximum size of the buffer to be read :returns: the read buffer :rtype: bytes :raises TypeError: invalid connection, invalid object, @@ -68,8 +70,8 @@ read, write, tell, seek, unlink -- file-like large object handling :raises ValueError: if `size` is negative :raises IOError: object is not opened, or read error -This function allows to read data from a large object, starting at current -position. +This function allows reading data from a large object, starting at the +current position. .. method:: LargeObject.write(string) @@ -80,7 +82,7 @@ position. :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: object is not opened, or write error -This function allows to write data to a large object, starting at current +This function allows writing data to a large object, starting at the current position. .. method:: LargeObject.seek(offset, whence) @@ -95,9 +97,9 @@ position. 
bad parameter type, or too many parameters
 :raises IOError: object is not opened, or seek error
 
-This method allows to move the position cursor in the large object.
-The valid values for the whence parameter are defined as constants in the
-:mod:`pg` module (:const:`SEEK_SET`, :const:`SEEK_CUR`, :const:`SEEK_END`).
+This method updates the position offset in the large object. The valid values
+for the whence parameter are defined as constants in the :mod:`pg` module
+(:const:`SEEK_SET`, :const:`SEEK_CUR`, :const:`SEEK_END`).
 
 .. method:: LargeObject.tell()
 
@@ -109,7 +111,7 @@ The valid values for the whence parameter are defined as constants in the
 :raises TypeError: too many parameters
 :raises IOError: object is not opened, or seek error
 
-This method allows to get the current position in the large object.
+This method returns the current position offset in the large object.
 
 .. method:: LargeObject.unlink()
 
@@ -135,7 +137,7 @@ size -- get the large object size
 :raises TypeError: too many parameters
 :raises IOError: object is not opened, or seek/tell error
 
-This (composite) method allows to get the size of a large object. It was
+This (composite) method returns the size of a large object. It was
 implemented because this function is very useful for a web interfaced
 database. Currently, the large object needs to be opened first.
 
@@ -152,14 +154,14 @@ export -- save a large object to a file
 bad parameter type, or too many parameters
 :raises IOError: object is not closed, or export error
 
-This methods allows to dump the content of a large object in a very simple
-way. The exported file is created on the host of the program, not the
-server host.
+This method allows saving the content of a large object to a file in a
+very simple way. The file is created on the host running the PyGreSQL
+interface, not on the server host.
 
 Object attributes
 -----------------
-:class:`LargeObject` objects define a read-only set of attributes that allow
-to get some information about it. These attributes are:
+:class:`LargeObject` objects define a read-only set of attributes exposing
+some information about them. These attributes are:
 
 .. attribute:: LargeObject.oid
 
 .. warning::
 
-   In multi-threaded environments, :attr:`LargeObject.error` may be modified by
-   another thread using the same :class:`Connection`. Remember these object
-   are shared, not duplicated. You should provide some locking to be able
-   if you want to check this. The :attr:`LargeObject.oid` attribute is very
+   In multi-threaded environments, :attr:`LargeObject.error` may be modified
+   by another thread using the same :class:`Connection`. Remember these
+   objects are shared, not duplicated. You should provide some locking if you
+   want to use this information in a program in which it's shared between
+   multiple threads. The :attr:`LargeObject.oid` attribute is very
    interesting, because it allows you to reuse the OID later, creating the
    :class:`LargeObject` object with a :meth:`Connection.getlo` method call. 
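To make the large object interface documented above concrete, here is a
minimal read/write round trip (an illustrative sketch only, not part of the
patch series; it assumes a reachable database named ``test`` with default
credentials, and the payload bytes are made up)::

    import pg

    con = pg.connect('test')           # assumed database name
    con.query('begin')                 # large objects live inside transactions
    lo = con.locreate(pg.INV_WRITE | pg.INV_READ)
    oid = lo.oid                       # keep the OID to find the object again
    lo.open(pg.INV_WRITE)
    lo.write(b'illustrative payload')  # the data written here is made up
    lo.close()
    con.query('commit')

    con.query('begin')
    lo = con.getlo(oid)                # recreate the wrapper from its OID
    lo.open(pg.INV_READ)
    data = lo.read(lo.size())          # size() requires the object to be open
    lo.close()
    con.query('commit')
    con.close()

Note how the object is reopened via :meth:`Connection.getlo` after the first
transaction ends; the server closes the descriptor at transaction end, which
is also why the patch above stops closing large objects in the deallocator.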
From f240d5fa38b3b49ab25e43c656dae9b560ec2e64 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 00:36:43 +0200 Subject: [PATCH 034/194] Use consistent formatting in the changelog --- docs/contents/changelog.rst | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 8139b27e..9f22a7b2 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,8 @@ ChangeLog ========= -Version 5.2 (to be released) ----------------------------- +Version 5.2 (2020-06-22) +------------------------ - We now require Python version 2.7 or 3.5 and newer. - All Python code is now tested with flake8 and made PEP8 compliant. - Changes to the classic PyGreSQL module (pg): @@ -13,7 +13,7 @@ Version 5.2 (to be released) - New query method `fieldinfo()` that gets name and type information for one or all field(s) of the query. Contributed by Justin Pryzby (#39). - Experimental support for asynchronous command processing. - Additional connection parameter ``nowait``, and connection methods + Additional connection parameter `nowait`, and connection methods `send_query()`, `poll()`, `set_non_blocking()`, `is_non_blocking()`. Generously contributed by Patrick TJ McPhee (#19). - The `types` parameter of `format_query` can now be passed as a string @@ -27,7 +27,6 @@ Version 5.2 (to be released) - Large objects on the server are not closed any more when they are deallocated as Python objects, since this could cause several problems. Bug report and analysis by Justin Pryzby (#30). - - Changes to the DB-API 2 module (pgdb): - When using Python 2, errors are now derived from StandardError instead of Exception, as required by the DB-API 2 compliance test. @@ -82,7 +81,6 @@ Version 5.1 (2019-05-17) and this function is not part of the official API. - Added new connection attributes `socket`, `backend_pid`, `ssl_in_use` and `ssl_attributes` (the latter need PostgreSQL >= 9.5 on the client). - - Changes to the DB-API 2 module (pgdb): - Connections now have an `autocommit` attribute which is set to `False` by default but can be set to `True` to switch to autocommit mode where From f3d57a486e710d5eca03ccb49f93758afcc3a6d1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 14:15:32 +0200 Subject: [PATCH 035/194] Minor change of wording in docs --- docs/contents/pg/large_objects.rst | 2 +- docs/contents/pg/query.rst | 24 ++++++++++++------------ pgquery.c | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index 0d48fd84..a1d9818d 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -11,7 +11,7 @@ all the recurring variables (object OID and connection), in the same way :class:`Connection` instances do, thus only keeping significant parameters in function calls. The :class:`LargeObject` instance keeps a reference to the :class:`Connection` object used for its creation, sending requests -though with its parameters. Any modification other than dereferencing the +through with its parameters. Any modification other than dereferencing the :class:`Connection` object will thus affect the :class:`LargeObject` instance. 
Dereferencing the initial :class:`Connection` object is not a problem since Python won't deallocate it before the :class:`LargeObject` instance diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 756e69ed..9e2998f8 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -328,19 +328,19 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. .. versionadded:: 5.1 -listfields -- list field names of previous query result -------------------------------------------------------- +listfields -- list field names of query result +---------------------------------------------- .. method:: Query.listfields() - List field names of previous query result + List field names of query result :returns: field names :rtype: list :raises TypeError: too many parameters -This method returns the list of field names defined for the -query result. The fields are in the same order as the result values. +This method returns the list of field names defined for the query result. +The fields are in the same order as the result values. fieldname, fieldnum -- field name/number conversion --------------------------------------------------- @@ -374,12 +374,12 @@ build a function that converts result list strings to their correct type, using a hardcoded table definition. The number returned is the field rank in the query result. -fieldinfo -- detailed info about fields of previous query result ----------------------------------------------------------------- +fieldinfo -- detailed info about query result fields +---------------------------------------------------- .. method:: Query.fieldinfo([field]) - Get information on one or all fields of the last query + Get information on one or all fields of the query :param field: a column number or name (optional) :type field: int or str @@ -389,10 +389,10 @@ fieldinfo -- detailed info about fields of previous query result :raises TypeError: too many parameters If the ``field`` is specified by passing either a column number or a field -name, a four-tuple with information for the specified field of the previous -query result will be returned. If no ``field`` is specified, a tuple of -four-tuples for every field of the previous query result will be returned, -in the order as they appear in the query result. +name, a four-tuple with information for the specified field of the query +result will be returned. If no ``field`` is specified, a tuple of four-tuples +for every field of the previous query result will be returned, in the same +order as they appear in the query result. The four-tuples contain the following information: The field name, the internal OID number of the field type, the size in bytes of the column or a diff --git a/pgquery.c b/pgquery.c index 335d6813..0382cf37 100644 --- a/pgquery.c +++ b/pgquery.c @@ -355,7 +355,7 @@ _query_build_field_info(PGresult *res, int col_num) { return info; } -/* Get information on one or all fields in last result. */ +/* Get information on one or all fields of the query result. 
*/ static char query_fieldinfo__doc__[] = "fieldinfo() -- return info on field(s) in query"; From 3e4ce801dc4b9a60027283c923a4f3fbea9b4bc0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 14:32:06 +0200 Subject: [PATCH 036/194] Check for proper Python versions in setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a95b0a2b..7ff1843a 100755 --- a/setup.py +++ b/setup.py @@ -54,8 +54,8 @@ version = '5.2' -if (not (2, 6) <= sys.version_info[:2] < (3, 0) - and not (3, 3) <= sys.version_info[:2] < (4, 0)): +if not (sys.version_info[:2] == (2, 7) + or (3, 5) <= sys.version_info[:2] < (4, 0)): raise Exception( "Sorry, PyGreSQL %s does not support this Python version" % version) From b7148ab9bee68e7516479cb043cfc6fbbbf874f7 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 14:36:07 +0200 Subject: [PATCH 037/194] Change release date --- docs/contents/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 9f22a7b2..e93bd56c 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,7 +1,7 @@ ChangeLog ========= -Version 5.2 (2020-06-22) +Version 5.2 (2020-06-21) ------------------------ - We now require Python version 2.7 or 3.5 and newer. - All Python code is now tested with flake8 and made PEP8 compliant. From f96cee91cfe5e5ca93770e2eb84ebe05eef41f4d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 16:39:23 +0200 Subject: [PATCH 038/194] Minor fix of index variable in inserttable method --- pgconn.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pgconn.c b/pgconn.c index 34d3cfb8..777a0bef 100644 --- a/pgconn.c +++ b/pgconn.c @@ -771,8 +771,8 @@ conn_inserttable(connObject *self, PyObject *args) if (columns) { /* adds a string like f" ({','.join(columns)})" */ bufpt += sprintf(bufpt, " ("); - for (int i = 0; i < n; ++i) { - PyObject *obj = getcolumn(columns, i); + for (j = 0; j < n; ++j) { + PyObject *obj = getcolumn(columns, j); ssize_t slen; char *col; @@ -790,7 +790,7 @@ conn_inserttable(connObject *self, PyObject *args) "The third argument must contain only strings"); } col = PQescapeIdentifier(self->cnx, col, (size_t) slen); - bufpt += sprintf(bufpt, "%s%s", col, i == n - 1 ? ")" : ","); + bufpt += sprintf(bufpt, "%s%s", col, j == n - 1 ? 
")" : ","); PQfreemem(col); } } From 18a1ceb2f2088fc73780306dd2c2e8ae7ab8bd41 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Jun 2020 17:15:59 +0200 Subject: [PATCH 039/194] Remove compiler warnings for 32-bit Linux --- pgconn.c | 2 +- pglarge.c | 2 +- pgquery.c | 2 +- pgsource.c | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pgconn.c b/pgconn.c index 777a0bef..ecf18a76 100644 --- a/pgconn.c +++ b/pgconn.c @@ -193,7 +193,7 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) } /* for a single insert, return the oid */ PQclear(result); - return PyInt_FromLong(oid); + return PyInt_FromLong((long) oid); } case PGRES_COPY_OUT: /* no data will be received */ case PGRES_COPY_IN: diff --git a/pglarge.c b/pglarge.c index 590f3fbb..f2c3a63e 100644 --- a/pglarge.c +++ b/pglarge.c @@ -85,7 +85,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyInt_FromLong(self->lo_oid); + return PyInt_FromLong((long) self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; diff --git a/pgquery.c b/pgquery.c index 0382cf37..50f837bb 100644 --- a/pgquery.c +++ b/pgquery.c @@ -348,7 +348,7 @@ _query_build_field_info(PGresult *res, int col_num) { info = PyTuple_New(4); if (info) { PyTuple_SET_ITEM(info, 0, PyStr_FromString(PQfname(res, col_num))); - PyTuple_SET_ITEM(info, 1, PyInt_FromLong(PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 1, PyInt_FromLong((long) PQftype(res, col_num))); PyTuple_SET_ITEM(info, 2, PyInt_FromLong(PQfsize(res, col_num))); PyTuple_SET_ITEM(info, 3, PyInt_FromLong(PQfmod(res, col_num))); } diff --git a/pgsource.c b/pgsource.c index 9ab94e36..2311e2a0 100644 --- a/pgsource.c +++ b/pgsource.c @@ -272,7 +272,7 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyInt_FromLong(oid); + return PyInt_FromLong((long) oid); } /* Fetch rows from last result. */ @@ -671,7 +671,7 @@ _source_buildinfo(sourceObject *self, int num) PyTuple_SET_ITEM(result, 1, PyStr_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyInt_FromLong(PQftype(self->result, num))); + PyInt_FromLong((long) PQftype(self->result, num))); PyTuple_SET_ITEM(result, 3, PyInt_FromLong(PQfsize(self->result, num))); PyTuple_SET_ITEM(result, 4, From 17914b61ef8460ccdd425e9c1666c046a691f93e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 12 Jul 2020 14:39:23 +0200 Subject: [PATCH 040/194] Support qualified table names in copy_from/to (#47) --- docs/contents/changelog.rst | 6 ++++++ pgdb.py | 16 ++++++++++------ tests/test_dbapi20_copy.py | 8 ++++++++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index e93bd56c..180710e9 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 5.2.1 (to be released) +------------------------------ +- Changes to the DB-API 2 module (pgdb): + - The `copy_to()` and `copy_from()` methods now also work with table names + containing schema qualifiers (#47). + Version 5.2 (2020-06-21) ------------------------ - We now require Python version 2.7 or 3.5 and newer. 
diff --git a/pgdb.py b/pgdb.py index ec94ceb5..976f3bdd 100644 --- a/pgdb.py +++ b/pgdb.py @@ -1268,10 +1268,11 @@ def chunks(): if not table or not isinstance(table, basestring): raise TypeError("Need a table to copy to") - if table.lower().startswith('select'): + if table.lower().startswith('select '): raise ValueError("Must specify a table, not a query") else: - table = '"%s"' % (table,) + table = '.'.join(map( + self.connection._cnx.escape_identifier, table.split('.', 1))) operation = ['copy %s' % (table,)] options = [] params = [] @@ -1299,7 +1300,8 @@ def chunks(): params.append(null) if columns: if not isinstance(columns, basestring): - columns = ','.join('"%s"' % (col,) for col in columns) + columns = ','.join(map( + self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) operation.append("from stdin") if options: @@ -1350,12 +1352,13 @@ def copy_to(self, stream, table, raise TypeError("Need an output stream to copy to") if not table or not isinstance(table, basestring): raise TypeError("Need a table to copy to") - if table.lower().startswith('select'): + if table.lower().startswith('select '): if columns: raise ValueError("Columns must be specified in the query") table = '(%s)' % (table,) else: - table = '"%s"' % (table,) + table = '.'.join(map( + self.connection._cnx.escape_identifier, table.split('.', 1))) operation = ['copy %s' % (table,)] options = [] params = [] @@ -1394,7 +1397,8 @@ def copy_to(self, stream, table, "The decode option is not allowed with binary format") if columns: if not isinstance(columns, basestring): - columns = ','.join('"%s"' % (col,) for col in columns) + columns = ','.join(map( + self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) operation.append("to stdout") diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 939a6828..dbe25fd1 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -253,6 +253,10 @@ def test_input_string(self): self.assertEqual(self.table_data, [(42, 'Hello, world!')]) self.check_rowcount(1) + def test_input_string_with_schema_name(self): + self.cursor.copy_from('42\tHello, world!', 'public.copytest') + self.assertEqual(self.table_data, [(42, 'Hello, world!')]) + def test_input_string_with_newline(self): self.copy_from('42\tHello, world!\n') self.assertEqual(self.table_data, [(42, 'Hello, world!')]) @@ -449,6 +453,10 @@ def test_generator(self): self.assertEqual(rows, self.data_text) self.check_rowcount() + def test_generator_with_schema_name(self): + ret = self.cursor.copy_to(None, 'public.copytest') + self.assertEqual(''.join(ret), self.data_text) + if str is unicode: # Python >= 3.0 def test_generator_bytes(self): From 165394eaf95dc3c900ea77f3ae50b78dc2298566 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 22 Aug 2020 12:35:37 +0200 Subject: [PATCH 041/194] You cannot use getresult() with an update query --- docs/contents/pg/connection.rst | 2 +- docs/contents/pg/db_wrapper.rst | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 842822ad..7d721b89 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -164,7 +164,7 @@ Examples:: con1 = connect() con2 = connect() s = con1.query("begin; set transaction isolation level repeatable read;" - "select pg_export_snapshot();").getresult()[0][0] + "select pg_export_snapshot();").single() con2.query("begin; set transaction isolation 
level repeatable read;" "set transaction snapshot '%s'" % (s,)) q1 = con1.send_query("select a,b,c from x where d=e") diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 92ef1ad9..d2ef4e05 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -455,11 +455,11 @@ Example:: name = input("Name? ") phone = input("Phone? ") - rows = db.query("update employees set phone=$2 where name=$1", - name, phone).getresult()[0][0] + num_rows = db.query("update employees set phone=$2 where name=$1", + name, phone) # or - rows = db.query("update employees set phone=$2 where name=$1", - (name, phone)).getresult()[0][0] + num_rows = db.query("update employees set phone=$2 where name=$1", + (name, phone)) query_formatted -- execute a formatted SQL command string --------------------------------------------------------- @@ -506,13 +506,13 @@ Example:: name = input("Name? ") phone = input("Phone? ") - rows = db.query_formatted( + num_rows = db.query_formatted( "update employees set phone=%s where name=%s", - (phone, name)).getresult()[0][0] + (phone, name)) # or - rows = db.query_formatted( + num_rows = db.query_formatted( "update employees set phone=%(phone)s where name=%(name)s", - dict(name=name, phone=phone)).getresult()[0][0] + dict(name=name, phone=phone)) Example with specification of types:: From 4023dd5c274997fd04cba99a0832baaa58b43d3a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 22 Aug 2020 12:36:03 +0200 Subject: [PATCH 042/194] Minor simplification --- pg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pg.py b/pg.py index bc447c71..f1bc62be 100644 --- a/pg.py +++ b/pg.py @@ -1863,7 +1863,7 @@ def get_parameter(self, parameter): else: for param in params: q = 'SHOW %s' % (param,) - value = self.db.query(q).getresult()[0][0] + value = self.db.query(q).singlescalar() if values is None: values = value elif isinstance(values, list): @@ -2192,7 +2192,7 @@ def has_table_privilege(self, table, privilege='select', flush=False): q = "SELECT pg_catalog.has_table_privilege(%s, $2)" % ( _quote_if_unqualified('$1', table),) q = self.db.query(q, (table, privilege)) - ret = q.getresult()[0][0] == self._make_bool(True) + ret = q.singlescalar() == self._make_bool(True) privileges[table, privilege] = ret # cache it return ret From 80ded9808439426bb428cd3f8dd5fb9116d89c5f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 13:00:04 +0200 Subject: [PATCH 043/194] Bump version number, support Py 3.9 and Pg 13 --- .bumpversion.cfg | 2 +- .travis.yml | 6 +++++- docs/about.txt | 4 ++-- docs/announce.rst | 12 ++++++------ docs/conf.py | 2 +- docs/contents/changelog.rst | 1 + docs/contents/install.rst | 2 +- docs/contents/pg/adaptation.rst | 2 +- docs/requirements.txt | 3 ++- setup.py | 9 +++++---- tests/test_classic_connection.py | 2 +- tests/test_classic_dbwrapper.py | 2 +- tests/test_classic_functions.py | 2 +- tox.ini | 4 ++-- 14 files changed, 30 insertions(+), 23 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 31f0ca1f..2a0dccdd 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2 +current_version = 5.2.1 commit = False tag = False diff --git a/.travis.yml b/.travis.yml index 5173e569..623d7d34 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,9 @@ matrix: - name: Code quality tests env: TOXENV=flake8,docs python: 3.8 + - name: Unit tests with Python 3.9 + env: TOXENV=py39 + python: 3.9 - name: Unit tests with Python 
3.8 env: TOXENV=py38 python: 3.8 @@ -36,7 +39,8 @@ script: - tox -e $TOXENV addons: - postgresql: "10" + # test with last version that still supports OIDs + postgresql: "11" services: - postgresql diff --git a/docs/about.txt b/docs/about.txt index 0274f84f..6fecd563 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2 needs PostgreSQL 9.0 to 9.6 or 10 to 12, and -Python 2.7 or 3.5 to 3.8. If you need to support older PostgreSQL versions or +The current version PyGreSQL 5.2.1 needs PostgreSQL 9.0 to 9.6 or 10 to 13, and +Python 2.7 or 3.5 to 3.9. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index 4c0194d8..a27f2c4f 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -2,11 +2,11 @@ PyGreSQL Announcements ====================== -------------------------------- -Release of PyGreSQL version 5.2 -------------------------------- +--------------------------------- +Release of PyGreSQL version 5.2.1 +--------------------------------- -Release 5.2 of PyGreSQL. +Release 5.2.1 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. @@ -22,8 +22,8 @@ This version has been built and unit tested on: - openSUSE - Ubuntu - Windows 7 and 10 with both MinGW and Visual Studio - - PostgreSQL 9.0 to 9.6 and 10 to 12 (32 and 64bit) - - Python 2.7 and 3.5 to 3.8 (32 and 64bit) + - PostgreSQL 9.0 to 9.6 and 10 to 13 (32 and 64bit) + - Python 2.7 and 3.5 to 3.9 (32 and 64bit) | D'Arcy J.M. Cain | darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index c20b8ab2..05e98a12 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2' +version = release = '5.2.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 180710e9..8c2c279e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,6 +3,7 @@ ChangeLog Version 5.2.1 (to be released) ------------------------------ +- This version officially supports the new Python 3.9 and PostgreSQL 13. - Changes to the DB-API 2 module (pgdb): - The `copy_to()` and `copy_from()` methods now also work with table names containing schema qualifiers (#47). diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 1843ef41..95c79d62 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -11,7 +11,7 @@ If you are on Windows, make sure that the directory that contains libpq.dll is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.7 and 3.5 to 3.8, and PostgreSQL versions 9.0 to 9.6 and 10 to 12. +2.7 and 3.5 to 3.9, and PostgreSQL versions 9.0 to 9.6 and 10 to 13. 
PyGreSQL will be installed as three modules, a shared library called _pg.so (on Linux) or a DLL called _pg.pyd (on Windows), and two pure diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index 8c09be23..b1ada9bd 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -363,7 +363,7 @@ With PostgreSQL we can easily calculate that these two circles overlap:: True However, calculating the intersection points between the two circles using the -``#`` operator does not work (at least not as of PostgreSQL version 12). +``#`` operator does not work (at least not as of PostgreSQL version 13). So let's resort to SymPy to find out. To ease importing circles from PostgreSQL to SymPy, we create and register the following typecast function:: diff --git a/docs/requirements.txt b/docs/requirements.txt index c354e8d9..dde5e355 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1 +1,2 @@ -cloud_sptheme>=1.7.1 \ No newline at end of file +sphinx>=3.2,<4 +cloud_sptheme>=1.10,<2 diff --git a/setup.py b/setup.py index 7ff1843a..f54f571d 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2 +"""Setup script for PyGreSQL version 5.2.1 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It embeds the PostgreSQL query library to allow @@ -26,8 +26,8 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.8, -and PostgreSQL versions 9.0 to 9.6 and 10 to 12. +PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.9, +and PostgreSQL versions 9.0 to 9.6 and 10 to 13. 
Use as follows: python setup.py build_ext # to build the module @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2' +version = '5.2.1' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): @@ -242,6 +242,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c7669d61..b54b3ab8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -195,7 +195,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 130000) + self.assertTrue(90000 <= server_version < 140000) def testAttributeSocket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index ee9a64fe..228a066d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -265,7 +265,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 130000) + self.assertTrue(90000 <= server_version < 140000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 101416b8..e59828cc 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -126,7 +126,7 @@ def testPqlibVersion(self): v = pg.get_pqlib_version() self.assertIsInstance(v, long) self.assertGreater(v, 90000) - self.assertLess(v, 130000) + self.assertLess(v, 140000) class TestParseArray(unittest.TestCase): diff --git a/tox.ini b/tox.ini index bd5f222e..28d9033a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py{27,35,36,37,38},flake8,docs +envlist = py{27,35,36,37,38,39},flake8,docs [testenv:flake8] basepython = python3.8 @@ -12,7 +12,7 @@ commands = [testenv:docs] basepython = python3.8 deps = - sphinx>=2.4,<3 + sphinx>=3.2,<4 cloud_sptheme>=1.10,<2 commands = sphinx-build -b html -nEW docs docs/_build/html From 1ed75f41febea684bfed79f332e8ee1a8751b50d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 13:09:15 +0200 Subject: [PATCH 044/194] Add apt packages for postgres to Travis config --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index 623d7d34..77d354ed 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,6 +41,10 @@ script: addons: # test with last version that still supports OIDs postgresql: "11" + apt: + packages: + - postgresql-11 + - postgresql-client-11 services: - postgresql From ff702d98d625388e90bcdef359e7d98aafa4bef2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 13:18:31 +0200 Subject: [PATCH 045/194] Travis: Remove old postgres before installing version 11 --- .travis.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 77d354ed..df37512b 100644 --- a/.travis.yml +++ 
b/.travis.yml @@ -41,13 +41,14 @@ script: addons: # test with last version that still supports OIDs postgresql: "11" - apt: - packages: - - postgresql-11 - - postgresql-client-11 services: - postgresql +before_install: + - sudo apt-get remove -y postgresql\* + - sudo apt-get install -y postgresql-11 postgresql-client-11 + - sudo service postgresql restart 11 + before_script: - psql -U postgres -c 'create database unittest' From 54de901567c773d845759787505d94dfa6984fcd Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 13:28:35 +0200 Subject: [PATCH 046/194] Travis: Fix port of postgres-11 installation --- .travis.yml | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index df37512b..d3cdd060 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,9 +8,9 @@ matrix: - name: Code quality tests env: TOXENV=flake8,docs python: 3.8 - - name: Unit tests with Python 3.9 - env: TOXENV=py39 - python: 3.9 + # - name: Unit tests with Python 3.9 + # env: TOXENV=py39 + # python: 3.9 - name: Unit tests with Python 3.8 env: TOXENV=py38 python: 3.8 @@ -41,14 +41,17 @@ script: addons: # test with last version that still supports OIDs postgresql: "11" + apt: + packages: + - postgresql-11 + - postgresql-client-11 services: - postgresql -before_install: - - sudo apt-get remove -y postgresql\* - - sudo apt-get install -y postgresql-11 postgresql-client-11 - - sudo service postgresql restart 11 - before_script: + - sudo service postgresql stop + - sudo -u postgres sed -i "s/port = 54[0-9][0-9]/port = 5432/" /etc/postgresql/11/main/postgresql.conf + - sudo service postgresql start 11 - psql -U postgres -c 'create database unittest' + From 2aa0dad2bfec7d526874e9df0480e4d9bb18d647 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 13:57:40 +0200 Subject: [PATCH 047/194] Travis: Create the database in a different way --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d3cdd060..89cbb0ad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,5 +53,6 @@ before_script: - sudo service postgresql stop - sudo -u postgres sed -i "s/port = 54[0-9][0-9]/port = 5432/" /etc/postgresql/11/main/postgresql.conf - sudo service postgresql start 11 - - psql -U postgres -c 'create database unittest' + - sudo cat /etc/postgresql/11/main/pg_hba.conf + - sudo -u postgres psql -c 'create database unittest' From 31994a905c37c1818cf1a1eb716d31fadc28db60 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 14:11:23 +0200 Subject: [PATCH 048/194] Travis: Remove conflicting postgresql versions again --- .travis.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index 89cbb0ad..d264f75a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,18 +41,15 @@ script: addons: # test with last version that still supports OIDs postgresql: "11" - apt: - packages: - - postgresql-11 - - postgresql-client-11 services: - postgresql before_script: - sudo service postgresql stop + - sudo apt-get remove -y postgresql\* + - sudo apt-get install postgresql-11 postgresql-client-11 - sudo -u postgres sed -i "s/port = 54[0-9][0-9]/port = 5432/" /etc/postgresql/11/main/postgresql.conf - sudo service postgresql start 11 - - sudo cat /etc/postgresql/11/main/pg_hba.conf - sudo -u postgres psql -c 'create database unittest' From f288787a5c22c705c838cf21fecc2ee865c31a40 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 
14:18:31 +0200 Subject: [PATCH 049/194] Install and just warn when requesting unsupported features We check pg_config to determine whether a feature is supported. But the features may still be supported by the installed libpq, or we may want to install the package elsewhere. This also solves problems when testing with Travis CI. --- .travis.yml | 8 +++++--- setup.py | 56 ++++++++++++++++++++++++++++++----------------------- tox.ini | 2 +- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/.travis.yml b/.travis.yml index d264f75a..dde3e96b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,16 +39,18 @@ script: - tox -e $TOXENV addons: - # test with last version that still supports OIDs + # last PostgreSQL version that still supports OIDs (11) postgresql: "11" + apt: + packages: + - postgresql-11 + - postgresql-server-dev-11 services: - postgresql before_script: - sudo service postgresql stop - - sudo apt-get remove -y postgresql\* - - sudo apt-get install postgresql-11 postgresql-client-11 - sudo -u postgres sed -i "s/port = 54[0-9][0-9]/port = 5432/" /etc/postgresql/11/main/postgresql.conf - sudo service postgresql start 11 - sudo -u postgres psql -c 'create database unittest' diff --git a/setup.py b/setup.py index f54f571d..dd773c82 100755 --- a/setup.py +++ b/setup.py @@ -116,18 +116,21 @@ class build_pg_ext(build_ext): ('escaping-funcs', None, "enable string escaping functions"), ('no-escaping-funcs', None, "disable string escaping functions"), ('ssl-info', None, "use new ssl info functions"), - ('no-ssl-info', None, "do not use new ssl info functions")] + ('no-ssl-info', None, "do not use new ssl info functions"), + ('memory-size', None, "enable new memory size function"), + ('no-memory-size', None, "disable new memory size function")] boolean_options = build_ext.boolean_options + [ 'strict', 'direct-access', 'large-objects', 'default-vars', - 'escaping-funcs', 'ssl-info'] + 'escaping-funcs', 'ssl-info', 'memory-size'] negative_opt = { 'no-direct-access': 'direct-access', 'no-large-objects': 'large-objects', 'no-default-vars': 'default-vars', 'no-escaping-funcs': 'escaping-funcs', - 'no-ssl-info': 'ssl-info'} + 'no-ssl-info': 'ssl-info', + 'no-memory-size': 'memory-size'} def get_compiler(self): """Return the C compiler used for building the extension.""" @@ -143,7 +146,8 @@ def initialize_options(self): self.pqlib_info = None self.ssl_info = None self.memory_size = None - if pg_version < (9, 0): + supported = pg_version >= (9, 0) + if not supported: warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.") @@ -158,32 +162,36 @@ def finalize_options(self): define_macros.append(('LARGE_OBJECTS', None)) if self.default_vars is None or self.default_vars: define_macros.append(('DEFAULT_VARS', None)) - if self.escaping_funcs is None or self.escaping_funcs: - if pg_version >= (9, 0): - define_macros.append(('ESCAPING_FUNCS', None)) - else: - (warnings.warn if self.escaping_funcs is None else sys.exit)( + wanted = self.escaping_funcs + supported = pg_version >= (9, 0) + if wanted or (wanted is None and supported): + define_macros.append(('ESCAPING_FUNCS', None)) + if not supported: + warnings.warn( "The installed PostgreSQL version" " does not support the newer string escaping functions.") - if self.pqlib_info is None or self.pqlib_info: - if pg_version >= (9, 1): - define_macros.append(('PQLIB_INFO', None)) - else: - (warnings.warn if self.pqlib_info is None else sys.exit)( + wanted = self.pqlib_info + supported = pg_version >= (9, 1) + if wanted or (wanted
 is None and supported): + define_macros.append(('PQLIB_INFO', None)) + if not supported: + warnings.warn( "The installed PostgreSQL version" " does not support PQLib info functions.") - if self.ssl_info is None or self.ssl_info: - if pg_version >= (9, 5): - define_macros.append(('SSL_INFO', None)) - else: - (warnings.warn if self.ssl_info is None else sys.exit)( + wanted = self.ssl_info + supported = pg_version >= (9, 5) + if wanted or (wanted is None and supported): + define_macros.append(('SSL_INFO', None)) + if not supported: + warnings.warn( "The installed PostgreSQL version" " does not support SSL info functions.") - if self.memory_size is None or self.memory_size: - if pg_version >= (12, 0): - define_macros.append(('MEMORY_SIZE', None)) - else: - (warnings.warn if self.memory_size is None else sys.exit)( + wanted = self.memory_size + supported = pg_version >= (12, 0) + if wanted or (wanted is None and supported): + define_macros.append(('MEMORY_SIZE', None)) + if not supported: + warnings.warn( "The installed PostgreSQL version" " does not support the memory size function.") if sys.platform == 'win32': diff --git a/tox.ini b/tox.ini index 28d9033a..6d22a5d7 100644 --- a/tox.ini +++ b/tox.ini @@ -19,5 +19,5 @@ commands = [testenv] commands = - python setup.py clean --all build_ext --force --inplace --strict + python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size python -m unittest discover {posargs} From 14f7bffb5f24610bb21ab75038f1a3c6bb196a7c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 25 Sep 2020 15:55:00 +0200 Subject: [PATCH 050/194] Add release date to the changelog --- docs/contents/changelog.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 8c2c279e..4a26ee68 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,12 +1,11 @@ ChangeLog ========= -Version 5.2.1 (to be released) +Version 5.2.1 (2020-09-25) ------------------------------ - This version officially supports the new Python 3.9 and PostgreSQL 13. -- Changes to the DB-API 2 module (pgdb): - - The `copy_to()` and `copy_from()` methods now also work with table names - containing schema qualifiers (#47). +- The `copy_to()` and `copy_from()` methods in the pgdb module now also work + with table names containing schema qualifiers (#47). Version 5.2 (2020-06-21) ------------------------ From 817e359c465c358a2dab51b4810d0339c70e7762 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Sep 2020 00:38:35 +0200 Subject: [PATCH 051/194] Use better link for the homepage The www variant sometimes gives a certificate warning. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b3950fdc..ad7c9b99 100644 --- a/README.rst +++ b/README.rst @@ -24,6 +24,6 @@ see the documentation. Documentation ------------- -The documentation is available at `www.pygresql.org <http://www.pygresql.org>`_. +The documentation is available at `pygresql.org <http://pygresql.org>`_. A mirror of the documentation can be found at `pygresql.readthedocs.io <https://pygresql.readthedocs.io>`_.
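Looking back at the setup.py change in patch 049: every feature flag now follows the same tri-state rule, where True forces a feature on, False forces it off, and None means auto-detection via pg_config. A simplified sketch of that rule (not the actual setup.py code):

    import warnings

    def resolve_feature(wanted, supported, name):
        """Tri-state build option: True=force on, False=force off, None=auto."""
        if wanted or (wanted is None and supported):
            if not supported:
                # explicitly requested: build anyway, but warn
                warnings.warn('The installed PostgreSQL version'
                              ' does not support ' + name)
            return True  # define the corresponding C macro
        return False

    assert resolve_feature(None, True, 'SSL info functions') is True
    assert resolve_feature(None, False, 'SSL info functions') is False
    assert resolve_feature(True, False, 'SSL info functions') is True  # warns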
From 8a3e892c93204e4e6a111d7ee8867300b10c57ce Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 2 Oct 2020 20:32:57 +0200 Subject: [PATCH 052/194] Add missing adapter for UUIDs --- docs/contents/changelog.rst | 6 +++++- pg.py | 8 ++++++++ tests/test_classic_dbwrapper.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 4a26ee68..ca9e1e25 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,12 @@ ChangeLog ========= -Version 5.2.1 (2020-09-25) +Version 5.2.2 (to be released) ------------------------------ +- Added a missing adapter method for UUIDs in the classic `pg` module. + +Version 5.2.1 (2020-09-25) +-------------------------- - This version officially supports the new Python 3.9 and PostgreSQL 13. - The `copy_to()` and `copy_from()` methods in the pgdb module now also work with table names containing schema qualifiers (#47). diff --git a/pg.py b/pg.py index f1bc62be..d3cb8902 100644 --- a/pg.py +++ b/pg.py @@ -487,6 +487,14 @@ def _adapt_hstore(self, v): return str(Hstore(v)) raise TypeError('Hstore parameter %s has wrong type' % v) + def _adapt_uuid(self, v): + """Adapt a UUID parameter.""" + if not v: + return None + if isinstance(v, basestring): + return v + return str(v) + @classmethod def _adapt_text_array(cls, v): """Adapt a text type array parameter.""" diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 228a066d..43dffaf5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4400,6 +4400,21 @@ def testAdaptQueryTypedWithHstore(self): params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + def testAdaptQueryTypedWithUuid(self): + format_query = self.adapter.format_query + value = '12345678-1234-5678-1234-567812345678' + sql, params = format_query("select %s", (value,), 'uuid') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + value = UUID('{12345678-1234-5678-1234-567812345678}') + sql, params = format_query("select %s", (value,), 'uuid') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + value = UUID('{12345678-1234-5678-1234-567812345678}') + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + def testAdaptQueryTypedDict(self): format_query = self.adapter.format_query self.assertRaises( From c515f0434249e8a5db1aadec889710bb792fa0f6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 29 Nov 2020 13:56:16 +0100 Subject: [PATCH 053/194] Performance optimization in the fetchmany() method (#51) This speeds up the creation of the result when there are many rows. 
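For context, a typical consumer of the code path optimized here looks like the following sketch (connection settings, table and query are made up):

    import pgdb

    con = pgdb.connect(database='unittest')
    cur = con.cursor()
    cur.execute('select n, n::text from generate_series(1, 10000) as s(n)')
    while True:
        rows = cur.fetchmany(500)  # the loop body optimized by this patch
        if not rows:
            break
        for n, s in rows:
            pass  # process each typecast row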
--- pgdb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pgdb.py b/pgdb.py index 976f3bdd..e139458a 100644 --- a/pgdb.py +++ b/pgdb.py @@ -1164,8 +1164,10 @@ def fetchmany(self, size=None, keep=False): except Error as err: raise _db_error(str(err)) typecast = self.type_cache.typecast - return [self.row_factory([typecast(value, typ) - for typ, value in zip(self.coltypes, row)]) for row in result] + row_factory = self.row_factory + coltypes = self.coltypes + return [row_factory([typecast(value, typ) + for typ, value in zip(coltypes, row)]) for row in result] def callproc(self, procname, parameters=None): """Call a stored database procedure with the given name. From df1043728bf727d66e35cd3bf5351989860538cd Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 29 Nov 2020 21:20:57 +0100 Subject: [PATCH 054/194] Fix reference counting issue in cast_array/record (#52) --- pgmodule.c | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pgmodule.c b/pgmodule.c index 80c3e043..ee4e101f 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -1055,12 +1055,10 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) return NULL; } - if (!cast_obj || cast_obj == Py_None) { - if (cast_obj) { - Py_DECREF(cast_obj); cast_obj = NULL; - } + if (cast_obj == Py_None) { + cast_obj = NULL; } - else if (!PyCallable_Check(cast_obj)) { + else if (cast_obj && !PyCallable_Check(cast_obj)) { PyErr_SetString( PyExc_TypeError, "Function cast_array() expects a callable as second argument"); @@ -1116,12 +1114,12 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) len = 0; } else if (cast_obj == Py_None) { - Py_DECREF(cast_obj); cast_obj = NULL; len = 0; + cast_obj = NULL; len = 0; } else if (PyTuple_Check(cast_obj) || PyList_Check(cast_obj)) { len = PySequence_Size(cast_obj); if (!len) { - Py_DECREF(cast_obj); cast_obj = NULL; + cast_obj = NULL; } } else { From f4f679338ea55855faa50bf4819ac52f731187f6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 30 Nov 2020 21:07:53 +0100 Subject: [PATCH 055/194] Ignore incompatible libpq.dll in Windows PATH (#53) If there is a compatible libpq.dll in the Windows PATH, but an incompatible libpq.dll comes earlier in the PATH, then this will not cause an ImportError anymore. Also, the ImportError when loading the shared library will now hint to a missing/incompatible libpq as probable cause. --- pg.py | 36 +++++++++++++++++++++++++----------- pgdb.py | 36 +++++++++++++++++++++++++----------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/pg.py b/pg.py index d3cb8902..05be042e 100644 --- a/pg.py +++ b/pg.py @@ -24,20 +24,34 @@ try: from _pg import * -except ImportError: +except ImportError as e: import os - import sys - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - if os.name == 'nt' and sys.version_info >= (3, 8): - for path in os.environ["PATH"].split(os.pathsep): - if os.path.exists(os.path.join(path, 'libpq.dll')): + libpq = 'libpq.' 
+ if os.name == 'nt': + libpq += 'dll' + import sys + paths = [path for path in os.environ["PATH"].split(os.pathsep) + if os.path.exists(os.path.join(path, libpq))] + if sys.version_info >= (3, 8): + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + for path in paths: with os.add_dll_directory(os.path.abspath(path)): - from _pg import * - break - else: - raise + try: + from _pg import * + except ImportError: + pass + else: + e = None + break + if paths: + libpq = 'compatible ' + libpq else: - raise + libpq += 'so' + if e: + # note: we could use "raise from e" here in Python 3 + raise ImportError( + "Cannot import shared library for PyGreSQL,\n" + "probably because no %s is installed.\n%s" % (libpq, e)) __version__ = version diff --git a/pgdb.py b/pgdb.py index e139458a..7dda8126 100644 --- a/pgdb.py +++ b/pgdb.py @@ -68,20 +68,34 @@ try: from _pg import * -except ImportError: +except ImportError as e: import os - import sys - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - if os.name == 'nt' and sys.version_info >= (3, 8): - for path in os.environ["PATH"].split(os.pathsep): - if os.path.exists(os.path.join(path, 'libpq.dll')): + libpq = 'libpq.' + if os.name == 'nt': + libpq += 'dll' + import sys + paths = [path for path in os.environ["PATH"].split(os.pathsep) + if os.path.exists(os.path.join(path, libpq))] + if sys.version_info >= (3, 8): + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + for path in paths: with os.add_dll_directory(os.path.abspath(path)): - from _pg import * - break - else: - raise + try: + from _pg import * + except ImportError: + pass + else: + e = None + break + if paths: + libpq = 'compatible ' + libpq else: - raise + libpq += 'so' + if e: + # note: we could use "raise from e" here in Python 3 + raise ImportError( + "Cannot import shared library for PyGreSQL,\n" + "probably because no %s is installed.\n%s" % (libpq, e)) __version__ = version From eab73876b362ba993a654ecea971f25fc2bc453f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 30 Nov 2020 21:41:34 +0100 Subject: [PATCH 056/194] Improve installation documentation (#55) Point out the need for installation of libpq more clearly. --- docs/contents/install.rst | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 95c79d62..bc9897c7 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -7,15 +7,18 @@ General You must first install Python and PostgreSQL on your system. If you want to access remote databases only, you don't need to install the full PostgreSQL server, but only the libpq C-interface library. -If you are on Windows, make sure that the directory that contains -libpq.dll is part of your ``PATH`` environment variable. +On Windows, this library is called ``libpq.dll`` and is for instance contained +in the PostgreSQL ODBC driver (search for "psqlodbc"). On Linux, it is called +``libpq.so`` and usually provided in a package called "libpq" or "libpq5". +On Windows, you also need to make sure that the directory that contains +``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions 2.7 and 3.5 to 3.9, and PostgreSQL versions 9.0 to 9.6 and 10 to 13. PyGreSQL will be installed as three modules, a shared library called -_pg.so (on Linux) or a DLL called _pg.pyd (on Windows), and two pure -Python wrapper modules called pg.py and pgdb.py. 
+``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure +Python wrapper modules called ``pg.py`` and ``pgdb.py``. All three files will be installed directly into the Python site-packages directory. To uninstall PyGreSQL, simply remove these three files. @@ -32,6 +35,9 @@ This will automatically try to find and download a distribution on the `Python Package Index <https://pypi.org/project/PyGreSQL/>`_ that matches your operating system and Python version and install it. +Note that you still need to have the libpq interface installed on your system +(see the general remarks above). + Installing from a Binary Distribution ------------------------------------- From 9671b83e7d4063b0bf521c8701ce7f9aae4532d3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 3 Dec 2020 18:11:19 +0100 Subject: [PATCH 057/194] Optimize the type casting in the pgdb module (#51) Optimizes performance by inlining the lookup in the typecast method and making use of the fact that LocalTypecasts never give KeyErrors. This should be further optimized for the "fetchmany" case. --- pgdb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgdb.py b/pgdb.py index 7dda8126..de49651a 100644 --- a/pgdb.py +++ b/pgdb.py @@ -838,7 +838,7 @@ def get_fields(self, typ): def get_typecast(self, typ): """Get the typecast function for the given database type.""" - return self._typecasts.get(typ) + return self._typecasts[typ] def set_typecast(self, typ, cast): """Set a typecast function for the specified database type(s).""" @@ -853,7 +853,7 @@ def typecast(self, value, typ): if value is None: # for NULL values, no typecast is necessary return None - cast = self.get_typecast(typ) + cast = self._typecasts[typ] if not cast or cast is str: # no typecast is necessary return value From 76db14da7862bd9677e2b45334f459ad231b335c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 3 Dec 2020 19:56:50 +0100 Subject: [PATCH 058/194] Optimize fetchall() when many rows are fetched (#51) In this case, look up all the type cast functions upfront. This gives more than 30% better performance for large query results. --- pgdb.py | 23 ++++++++++++++++++++--- tests/test_dbapi20.py | 24 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/pgdb.py b/pgdb.py index de49651a..046df7b7 100644 --- a/pgdb.py +++ b/pgdb.py @@ -854,11 +854,23 @@ def typecast(self, value, typ): # for NULL values, no typecast is necessary return None cast = self._typecasts[typ] - if not cast or cast is str: + if cast is None or cast is str: # no typecast is necessary return value return cast(value) + def get_row_caster(self, types): + """Get a typecast function for a complete row of values.""" + typecasts = self._typecasts + casts = [typecasts[typ] for typ in types] + casts = [cast if cast is not str else None for cast in casts] + + def row_caster(row): + return [value if cast is None or value is None else cast(value) + for cast, value in zip(casts, row)] + + return row_caster + class _quotedict(dict): """Dictionary with auto quoting of its items.
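The row_caster closure added in the hunk above can be shown in isolation; a simplified standalone sketch of the same idea (not the module code itself):

    def get_row_caster(casts):
        # str casts would be no-ops, so replace them by None upfront
        casts = [cast if cast is not str else None for cast in casts]

        def cast_row(row):
            # apply the precomputed per-column casts to one row of values
            return [value if cast is None or value is None else cast(value)
                    for cast, value in zip(casts, row)]

        return cast_row

    cast_row = get_row_caster([int, str, float])
    assert cast_row(['1', 'x', '2.5']) == [1, 'x', 2.5]
    assert cast_row(['7', None, None]) == [7, None, None]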
@@ -1177,10 +1189,15 @@ def fetchmany(self, size=None, keep=False): raise except Error as err: raise _db_error(str(err)) - typecast = self.type_cache.typecast row_factory = self.row_factory coltypes = self.coltypes - return [row_factory([typecast(value, typ) + if len(result) > 5: + # optimize the case where we really fetch many values + # by looking up all type casting functions upfront + cast_row = self.type_cache.get_row_caster(coltypes) + return [row_factory(cast_row(row)) for row in result] + cast_value = self.type_cache.typecast + return [row_factory([cast_value(value, typ) for typ, value in zip(coltypes, row)]) for row in result] def callproc(self, procname, parameters=None): diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index aadde468..3dad0d7a 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1137,6 +1137,30 @@ def test_execute_edge_cases(self): sql = 'select 1' # cannot be executed after connection is closed self.assertRaises(pgdb.OperationalError, cur.execute, sql) + def test_fetchall_with_various_sizes(self): + # we test this because there are optimizations based on result size + con = self._connect() + try: + for n in (1, 3, 5, 7, 10, 100, 1000): + cur = con.cursor() + try: + cur.execute('select n, n::text as s, n %% 2 = 1 as b' + ' from generate_series(1, %d) as s(n)' % n) + res = cur.fetchall() + self.assertEqual(len(res), n, res) + self.assertEqual(len(res[0]), 3) + self.assertEqual(res[0].n, 1) + self.assertEqual(res[0].s, '1') + self.assertIs(res[0].b, True) + self.assertEqual(len(res[-1]), 3) + self.assertEqual(res[-1].n, n) + self.assertEqual(res[-1].s, str(n)) + self.assertIs(res[-1].b, n % 2 == 1) + finally: + cur.close() + finally: + con.close() + def test_fetchmany_with_keep(self): con = self._connect() try: From 8d38103424aea11ef1888cae3cd5c51c02325ebe Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 11:01:35 +0100 Subject: [PATCH 059/194] Update the changelog --- docs/contents/changelog.rst | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index ca9e1e25..3389e0b7 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,9 +1,12 @@ ChangeLog ========= -Version 5.2.2 (to be released) ------------------------------- +Version 5.2.2 (2020-12-09) +-------------------------- - Added a missing adapter method for UUIDs in the classic `pg` module. +- Performance optimizations for `fetchmany()` in the `pgdb` module (#51). +- Fixed a reference counting issues in the `cast_array/record` methods (#52). +- Ignore incompatible libpq.dll in Windows PATH for Python >= 3.8 (#53). Version 5.2.1 (2020-09-25) -------------------------- @@ -97,15 +100,15 @@ Version 5.1 (2019-05-17) no transactions are started and calling commit() is not required. Note that this is not part of the DB-API 2 standard. -Vesion 5.0.7 (2019-05-17) -------------------------- +Version 5.0.7 (2019-05-17) +-------------------------- - This version officially supports the new PostgreSQL 11. - Fixed a bug in parsing array subscript ranges (reported by Justin Pryzby). - Fixed an issue when deleting a DB wrapper object with the underlying connection already closed (bug report by Jacob Champion). -Vesion 5.0.6 (2018-07-29) -------------------------- +Version 5.0.6 (2018-07-29) +-------------------------- - This version officially supports the new Python 3.7. - Correct trove classifier for the PostgreSQL License. 
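Note that the fetchmany() change above takes the precomputed row-caster path only for results of more than five rows; smaller results keep the per-value typecast call. A usage sketch of the fast path, assuming a pgdb cursor cur:

    cur.execute('select n, n::text as s'
                ' from generate_series(1, 100) as s(n)')
    rows = cur.fetchall()  # more than five rows: uses the row caster
    assert rows[0].n == 1 and rows[-1].s == '100'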
From d3ed20931daf571dd3fc8527faf00e2091d14417 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 11:42:33 +0100 Subject: [PATCH 060/194] Make flaky notification test more robust --- tests/test_classic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_classic.py b/tests/test_classic.py index 15db5060..16deffee 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -302,12 +302,12 @@ def test_notify_timeout(self): NotificationHandler, db) arg_dict = dict(event=None, called=False) self.notify_timeout = False - # Listen for 'event_1' with timeout of 10ms. - target = fut('event_1', self.notify_callback, arg_dict, 0.01) + # Listen for 'event_1' with timeout of 50ms. + target = fut('event_1', self.notify_callback, arg_dict, 0.05) thread = Thread(None, target) thread.start() - # Sleep 20ms, long enough to time out. - sleep(0.02) + # Sleep 100ms, long enough to time out. + sleep(0.1) # Verify that we've indeed timed out. self.assertFalse(arg_dict.get('called')) self.assertTrue(self.notify_timeout) From cecc50089ecf63877efb300e4f7cd6c1def5bc91 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 11:43:32 +0100 Subject: [PATCH 061/194] InternalError was missing in list of exported pgdb symbols --- pgdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgdb.py b/pgdb.py index 046df7b7..3eee8de1 100644 --- a/pgdb.py +++ b/pgdb.py @@ -111,7 +111,7 @@ 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', - 'IntegrityError', 'ProgrammingError', 'NotSupportedError', + 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', 'apilevel', 'connect', 'paramstyle', 'threadsafety', 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] From f361f854262066804c6178e36a3bc5306503df4c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 11:52:23 +0100 Subject: [PATCH 062/194] Make flaky notification test even more robust --- tests/test_classic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_classic.py b/tests/test_classic.py index 16deffee..c3b731d1 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -306,8 +306,8 @@ def test_notify_timeout(self): target = fut('event_1', self.notify_callback, arg_dict, 0.05) thread = Thread(None, target) thread.start() - # Sleep 100ms, long enough to time out. - sleep(0.1) + # Sleep 250ms, long enough to time out. + sleep(0.25) # Verify that we've indeed timed out. self.assertFalse(arg_dict.get('called')) self.assertTrue(self.notify_timeout) From 0f9275ae861aef7e4dd159db24c3e136a4419de8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 11:55:19 +0100 Subject: [PATCH 063/194] Bump version number --- .bumpversion.cfg | 2 +- docs/about.txt | 2 +- docs/announce.rst | 4 ++-- docs/conf.py | 2 +- setup.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 2a0dccdd..24ad1614 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.1 +current_version = 5.2.2 commit = False tag = False diff --git a/docs/about.txt b/docs/about.txt index 6fecd563..bff0af4f 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). 
D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.1 needs PostgreSQL 9.0 to 9.6 or 10 to 13, and +The current version PyGreSQL 5.2.2 needs PostgreSQL 9.0 to 9.6 or 10 to 13, and Python 2.7 or 3.5 to 3.9. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index a27f2c4f..6fc4ac49 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -3,10 +3,10 @@ PyGreSQL Announcements ====================== --------------------------------- -Release of PyGreSQL version 5.2.1 +Release of PyGreSQL version 5.2.2 --------------------------------- -Release 5.2.1 of PyGreSQL. +Release 5.2.2 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. diff --git a/docs/conf.py b/docs/conf.py index 05e98a12..b8eafe7d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2.1' +version = release = '5.2.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index dd773c82..ed02c404 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.1 +"""Setup script for PyGreSQL version 5.2.2 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It embeds the PostgreSQL query library to allow @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.1' +version = '5.2.2' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): From c6be423e12ffe7e0887fdadc8362771d0f38a07a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 9 Dec 2020 15:09:52 +0100 Subject: [PATCH 064/194] Fix typo in ChangeLog --- docs/contents/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 3389e0b7..fec3199f 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,7 +5,7 @@ Version 5.2.2 (2020-12-09) -------------------------- - Added a missing adapter method for UUIDs in the classic `pg` module. - Performance optimizations for `fetchmany()` in the `pgdb` module (#51). -- Fixed a reference counting issues in the `cast_array/record` methods (#52). +- Fixed a reference counting issue in the `cast_array/record` methods (#52). - Ignore incompatible libpq.dll in Windows PATH for Python >= 3.8 (#53). 
Version 5.2.1 (2020-09-25) From 9facbfca1987c87d546e64cd84b5311d29bf879b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 8 Jan 2021 22:46:48 +0100 Subject: [PATCH 065/194] Fix refcounting issue when casting from JSON (#57) --- pginternal.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pginternal.c b/pginternal.c index e1d36692..25e1dcc8 100644 --- a/pginternal.c +++ b/pginternal.c @@ -240,8 +240,8 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) /* this type should only be passed when jsondecode is set */ obj = get_decoded_string(s, size, encoding); if (obj && jsondecode) { /* was able to decode */ - tmp_obj = Py_BuildValue("(O)", obj); - obj = PyObject_CallObject(jsondecode, tmp_obj); + tmp_obj = obj; + obj = PyObject_CallFunction(jsondecode, "(O)", obj); Py_DECREF(tmp_obj); } break; From a29619bdff24742f413467e829047a39a56245a3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 8 Jan 2021 23:15:57 +0100 Subject: [PATCH 066/194] Update changelog --- docs/contents/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index fec3199f..6c2ba061 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 5.2.3 (to be released) +------------------------------ +- Fixed a reference counting issue when casting JSON columns (#57). + Version 5.2.2 (2020-12-09) -------------------------- - Added a missing adapter method for UUIDs in the classic `pg` module. From 6969e74a7020e502dbec739457ccdc73a5fca020 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 24 Jan 2021 20:31:39 +0100 Subject: [PATCH 067/194] Add missing get/set_typecasts in list of exports --- pg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pg.py b/pg.py index 05be042e..6c1bd35c 100644 --- a/pg.py +++ b/pg.py @@ -75,12 +75,12 @@ 'get_array', 'get_bool', 'get_bytea_escaped', 'get_datestyle', 'get_decimal', 'get_decimal_point', 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', + 'get_jsondecode', 'get_typecast', 'set_array', 'set_bool', 'set_bytea_escaped', 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', + 'set_jsondecode', 'set_query_helpers', 'set_typecast', 'version', '__version__'] import select From 16804ee0abf27ff512d07e19bbf5d10308a6b7cd Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 28 Mar 2021 15:11:08 +0200 Subject: [PATCH 068/194] Fix argument handling of is/set_non_blocking() --- pgconn.c | 27 ++++++++------------------- tests/test_classic_connection.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pgconn.c b/pgconn.c index ecf18a76..302dd29d 100644 --- a/pgconn.c +++ b/pgconn.c @@ -641,7 +641,9 @@ conn_set_non_blocking(connObject *self, PyObject *args) } if (!PyArg_ParseTuple(args, "i", &non_blocking)) { - PyErr_SetString(PyExc_TypeError, "setnonblocking(tf), with boolean."); + PyErr_SetString( + PyExc_TypeError, + "set_non_blocking() expects a boolean value as argument"); return NULL; } @@ -658,7 +660,7 @@ static char conn_is_non_blocking__doc__[] = "is_non_blocking() -- report the blocking status of the connection"; static PyObject * -conn_is_non_blocking(connObject *self, PyObject *args) +conn_is_non_blocking(connObject *self, PyObject *noargs) { int rc; @@ 
-667,19 +669,13 @@ conn_is_non_blocking(connObject *self, PyObject *args) return NULL; } - if (!PyArg_ParseTuple(args, "")) { - PyErr_SetString(PyExc_TypeError, - "method is_non_blocking() takes no parameters"); - return NULL; - } - rc = PQisnonblocking(self->cnx); if (rc < 0) { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); return NULL; } - return PyBool_FromLong(rc); + return PyBool_FromLong((long)rc); } #endif /* DIRECT_ACCESS */ @@ -1426,7 +1422,7 @@ static char conn_poll__doc__[] = "poll() -- Completes an asynchronous connection"; static PyObject * -conn_poll(connObject *self, PyObject *args) +conn_poll(connObject *self, PyObject *noargs) { int rc; @@ -1435,13 +1431,6 @@ conn_poll(connObject *self, PyObject *args) return NULL; } - /* check args */ - if (!PyArg_ParseTuple(args, "")) { - PyErr_SetString(PyExc_TypeError, - "method poll() takes no parameters"); - return NULL; - } - Py_BEGIN_ALLOW_THREADS rc = PQconnectPoll(self->cnx); Py_END_ALLOW_THREADS @@ -1612,7 +1601,7 @@ static struct PyMethodDef conn_methods[] = { {"describe_prepared", (PyCFunction) conn_describe_prepared, METH_VARARGS, conn_describe_prepared__doc__}, {"poll", (PyCFunction) conn_poll, - METH_VARARGS, conn_poll__doc__}, + METH_NOARGS, conn_poll__doc__}, {"reset", (PyCFunction) conn_reset, METH_NOARGS, conn_reset__doc__}, {"cancel", (PyCFunction) conn_cancel, @@ -1659,7 +1648,7 @@ static struct PyMethodDef conn_methods[] = { {"endcopy", (PyCFunction) conn_endcopy, METH_NOARGS, conn_endcopy__doc__}, {"set_non_blocking", (PyCFunction) conn_set_non_blocking, - METH_O, conn_set_non_blocking__doc__}, + METH_VARARGS, conn_set_non_blocking__doc__}, {"is_non_blocking", (PyCFunction) conn_is_non_blocking, METH_NOARGS, conn_is_non_blocking__doc__}, #endif /* DIRECT_ACCESS */ diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index b54b3ab8..4bb7336c 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -84,8 +84,15 @@ class TestCanConnect(unittest.TestCase): def testCanConnect(self): try: connection = connect() + rc = connection.poll() except pg.Error as error: self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.assertEqual(rc, pg.POLLING_OK) + self.assertIs(connection.is_non_blocking(), False) + connection.set_non_blocking(True) + self.assertIs(connection.is_non_blocking(), True) + connection.set_non_blocking(False) + self.assertIs(connection.is_non_blocking(), False) try: connection.close() except pg.Error: @@ -93,13 +100,19 @@ def testCanConnect(self): def testCanConnectNoWait(self): try: - connection = connect() + connection = connect_nowait() rc = connection.poll() + self.assertEqual(rc, pg.POLLING_READING) while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): rc = connection.poll() except pg.Error as error: self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) self.assertEqual(rc, pg.POLLING_OK) + self.assertIs(connection.is_non_blocking(), False) + connection.set_non_blocking(True) + self.assertIs(connection.is_non_blocking(), True) + connection.set_non_blocking(False) + self.assertIs(connection.is_non_blocking(), False) try: connection.close() except pg.Error: From 2c61c3e4e67b8c5b81fd299c8dc7587d12742ec7 Mon Sep 17 00:00:00 2001 From: Justin Pryzby Date: Sat, 27 Mar 2021 21:08:57 -0500 Subject: [PATCH 069/194] doc fixes --- docs/contents/pg/connection.rst | 6 +++--- pgconn.c | 2 +- pgquery.c | 10 ++++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/contents/pg/connection.rst 
b/docs/contents/pg/connection.rst index 7d721b89..cd5a016c 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -103,12 +103,12 @@ returns without waiting for the query to complete. The database connection cannot be used for other operations until the query completes, but the application can do other things, including executing queries using other database connections. The application can call ``select()`` using the -``fileno``` obtained by the connection#s :meth:`Connection.fileno` method +``fileno`` obtained by the connection's :meth:`Connection.fileno` method to determine when the query has results to return. This method always returns a :class:`Query` object. This object differs from the :class:`Query` object returned by :meth:`Connection.query` in a -few way. Most importantly, when :meth:`Connection.send_query` is used, the +few ways. Most importantly, when :meth:`Connection.send_query` is used, the application must call one of the result-returning methods such as :meth:`Query.getresult` or :meth:`Query.dictresult` until it either raises an exception or returns ``None``. @@ -285,7 +285,7 @@ it's no different from a connection made using blocking calls. The required steps are to pass the parameter ``nowait=True`` to the :meth:`pg.connect` call, then call :meth:`Connection.poll` until it either -returns :const':`POLLING_OK` or raises an exception. To avoid blocking +returns :const:`POLLING_OK` or raises an exception. To avoid blocking in :meth:`Connection.poll`, use `select()` or `poll()` to wait for the connection to be readable or writable, depending on the return code of the previous call to :meth:`Connection.poll`. The initial state of the connection diff --git a/pgconn.c b/pgconn.c index 302dd29d..958131c6 100644 --- a/pgconn.c +++ b/pgconn.c @@ -511,7 +511,7 @@ conn_describe_prepared(connObject *self, PyObject *args) /* reads args */ if (!PyArg_ParseTuple(args, "s#", &name, &name_length)) { PyErr_SetString(PyExc_TypeError, - "Method prepare() takes a string argument"); + "Method describe_prepared() takes a string argument"); return NULL; } diff --git a/pgquery.c b/pgquery.c index 50f837bb..a5dc34c5 100644 --- a/pgquery.c +++ b/pgquery.c @@ -223,8 +223,10 @@ static PyObject* query_iter(queryObject *self) return (PyObject*) self; } -/* __next__() method of the queryObject: - Returns the current current row as a tuple and moves to the next one. */ +/* + * __next__() method of the queryObject: + * Returns the current row as a tuple and moves to the next one. + */ static PyObject * query_next(queryObject *self, PyObject *noargs) { @@ -357,7 +359,7 @@ _query_build_field_info(PGresult *res, int col_num) { /* Get information on one or all fields of the query result. */ static char query_fieldinfo__doc__[] = -"fieldinfo() -- return info on field(s) in query"; +"fieldinfo([name]) -- return information about field(s) in query result"; static PyObject * query_fieldinfo(queryObject *self, PyObject *args) @@ -522,7 +524,7 @@ _query_row_as_dict(queryObject *self) return row_dict; } -/* Return the current current row as a dict and move to the next one. */ +/* Return the current row as a dict and move to the next one. 
*/ static PyObject * query_next_dict(queryObject *self, PyObject *noargs) { From 164399a9f715bcb211834f2cc4387c4ba9d5659d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 28 Mar 2021 16:21:28 +0200 Subject: [PATCH 070/194] Use consistent style of comment formatting --- pgquery.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pgquery.c b/pgquery.c index a5dc34c5..63b86c32 100644 --- a/pgquery.c +++ b/pgquery.c @@ -223,10 +223,8 @@ static PyObject* query_iter(queryObject *self) return (PyObject*) self; } -/* - * __next__() method of the queryObject: - * Returns the current row as a tuple and moves to the next one. - */ +/* __next__() method of the queryObject: + Returns the current row as a tuple and moves to the next one. */ static PyObject * query_next(queryObject *self, PyObject *noargs) { @@ -359,7 +357,7 @@ _query_build_field_info(PGresult *res, int col_num) { /* Get information on one or all fields of the query result. */ static char query_fieldinfo__doc__[] = -"fieldinfo([name]) -- return information about field(s) in query result"; +"fieldinfo([name]) -- return information about field(s) in query result"; static PyObject * query_fieldinfo(queryObject *self, PyObject *args) @@ -524,7 +522,7 @@ _query_row_as_dict(queryObject *self) return row_dict; } -/* Return the current row as a dict and move to the next one. */ +/* Return the current row as a dict and move to the next one. */ static PyObject * query_next_dict(queryObject *self, PyObject *noargs) { From fef206b7137dbbb029d2923840d4023243f3b449 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 2 Sep 2021 19:42:27 +0200 Subject: [PATCH 071/194] Fix compilation issue on cp310-win #63 --- pgconn.c | 4 ++-- tox.ini | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pgconn.c b/pgconn.c index 958131c6..dd6017bc 100644 --- a/pgconn.c +++ b/pgconn.c @@ -511,7 +511,7 @@ conn_describe_prepared(connObject *self, PyObject *args) /* reads args */ if (!PyArg_ParseTuple(args, "s#", &name, &name_length)) { PyErr_SetString(PyExc_TypeError, - "Method describe_prepared() takes a string argument"); + "Method describe_prepared() takes a string argument"); return NULL; } @@ -769,7 +769,7 @@ conn_inserttable(connObject *self, PyObject *args) bufpt += sprintf(bufpt, " ("); for (j = 0; j < n; ++j) { PyObject *obj = getcolumn(columns, j); - ssize_t slen; + Py_ssize_t slen; char *col; if (PyBytes_Check(obj)) { diff --git a/tox.ini b/tox.ini index 6d22a5d7..fa945503 100644 --- a/tox.ini +++ b/tox.ini @@ -1,16 +1,16 @@ # config file for tox [tox] -envlist = py{27,35,36,37,38,39},flake8,docs +envlist = py{27,35,36,37,38,39,310},flake8,docs [testenv:flake8] -basepython = python3.8 +basepython = python3.9 deps = flake8>=3.8,<4 commands = flake8 setup.py pg.py pgdb.py tests [testenv:docs] -basepython = python3.8 +basepython = python3.9 deps = sphinx>=3.2,<4 cloud_sptheme>=1.10,<2 From a0025f85b2062d48c36d4c04f301850dcc261022 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 28 Jan 2022 21:46:32 +0100 Subject: [PATCH 072/194] Catch buffer overflows in inserttable function These could happen if the columns specification was extremely large. 
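Seen from Python, the effect of the guard added below is that an oversized column specification now fails cleanly, as the new unit test verifies; in sketch form, assuming an open classic connection con:

    # an absurdly long column specification must no longer overflow the
    # internal query buffer; it raises MemoryError instead
    try:
        con.inserttable('test', [], ['very_long_column_name'] * 1000)
    except MemoryError:
        print('overflow caught cleanly')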
--- pgconn.c | 18 +++++++++++++----- tests/test_classic_connection.py | 5 +++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/pgconn.c b/pgconn.c index dd6017bc..383a48bd 100644 --- a/pgconn.c +++ b/pgconn.c @@ -690,7 +690,7 @@ static PyObject * conn_inserttable(connObject *self, PyObject *args) { PGresult *result; - char *table, *buffer, *bufpt; + char *table, *buffer, *bufpt, *bufmax; int encoding; size_t bufsiz; PyObject *list, *sublist, *item, *columns = NULL; @@ -761,12 +761,14 @@ conn_inserttable(connObject *self, PyObject *args) /* starts query */ bufpt = buffer; + bufmax = bufpt + MAX_BUFFER_SIZE; table = PQescapeIdentifier(self->cnx, table, strlen(table)); - bufpt += sprintf(bufpt, "copy %s", table); + bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy %s", table); PQfreemem(table); if (columns) { /* adds a string like f" ({','.join(columns)})" */ - bufpt += sprintf(bufpt, " ("); + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), " ("); for (j = 0; j < n; ++j) { PyObject *obj = getcolumn(columns, j); Py_ssize_t slen; @@ -786,11 +788,17 @@ conn_inserttable(connObject *self, PyObject *args) "The third argument must contain only strings"); } col = PQescapeIdentifier(self->cnx, col, (size_t) slen); - bufpt += sprintf(bufpt, "%s%s", col, j == n - 1 ? ")" : ","); + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), + "%s%s", col, j == n - 1 ? ")" : ","); PQfreemem(col); } } - sprintf(bufpt, " from stdin"); + if (bufpt < bufmax) + snprintf(bufpt, (size_t) (bufmax - bufpt), " from stdin"); + if (bufpt >= bufmax) { + PyMem_Free(buffer); return PyErr_NoMemory(); + } Py_BEGIN_ALLOW_THREADS result = PQexec(self->cnx, buffer); diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 4bb7336c..b51a20f8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2075,6 +2075,11 @@ def testInserttableNoEncoding(self): # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) + def testInserttableTooLargeColumnSpecification(self): + # should catch buffer overflow when building the column specification + self.assertRaises(MemoryError, self.c.inserttable, + 'test', [], ['very_long_column_name'] * 1000) + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From 8b777081dc3a29aace86ed94b47182ab4371247e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 00:46:58 +0100 Subject: [PATCH 073/194] Make inserttable() accept any iterable (#66) --- docs/contents/pg/connection.rst | 28 +++---- pgconn.c | 130 ++++++++++++++++--------------- tests/test_classic_connection.py | 37 +++++++-- 3 files changed, 112 insertions(+), 83 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index cd5a016c..a9fccfdf 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -103,12 +103,12 @@ returns without waiting for the query to complete. The database connection cannot be used for other operations until the query completes, but the application can do other things, including executing queries using other database connections. The application can call ``select()`` using the -``fileno`` obtained by the connection's :meth:`Connection.fileno` method +``fileno`` obtained by the connection's :meth:`Connection.fileno` method to determine when the query has results to return. 
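A minimal sketch of this polling pattern (assuming an already opened
connection ``con``; the query shown is just an illustration)::

    import select

    query = con.send_query("select pg_sleep(2), 42")
    while True:
        # wait until the connection's socket becomes readable
        select.select([con.fileno()], [], [])
        result = query.getresult()
        if result is None:
            break  # query complete, connection is usable again
        print(result)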
This method always returns a :class:`Query` object. This object differs
from the :class:`Query` object returned by :meth:`Connection.query` in a
-few ways. Most importantly, when :meth:`Connection.send_query` is used, the
+few ways. Most importantly, when :meth:`Connection.send_query` is used, the
application must call one of the result-returning methods such as
:meth:`Query.getresult` or :meth:`Query.dictresult` until it either
raises an exception or returns ``None``.
@@ -285,7 +285,7 @@ it's no different from a connection made using blocking calls.

The required steps are to pass the parameter ``nowait=True`` to the
:meth:`pg.connect` call, then call :meth:`Connection.poll` until it either
-returns :const:`POLLING_OK` or raises an exception. To avoid blocking
+returns :const:`POLLING_OK` or raises an exception. To avoid blocking
in :meth:`Connection.poll`, use `select()` or `poll()` to wait for the
connection to be readable or writable, depending on the return code of the
previous call to :meth:`Connection.poll`. The initial state of the connection
@@ -484,27 +484,27 @@ first, otherwise :meth:`Connection.getnotify` will always return ``None``.
.. versionchanged:: 4.1
    Support for payload strings was added in version 4.1.

-inserttable -- insert a list into a table
------------------------------------------
+inserttable -- insert an iterable into a table
+----------------------------------------------

.. method:: Connection.inserttable(table, values, [columns])

-    Insert a Python list into a database table
+    Insert a Python iterable into a database table

    :param str table: the table name
-    :param list values: list of rows values
-    :param list columns: list of column names
+    :param list values: iterable of row values, which must be lists or tuples
+    :param list columns: list or tuple of column names
    :rtype: None
    :raises TypeError: invalid connection, bad argument type, or too many arguments
    :raises MemoryError: insert buffer could not be allocated
    :raises ValueError: unsupported values

-This method allows to *quickly* insert large blocks of data in a table:
-It inserts the whole values list into the given table. Internally, it
-uses the COPY command of the PostgreSQL database. The list is a list
-of tuples/lists that define the values for each inserted row. The rows
-values may contain string, integer, long or double (real) values.
-``columns`` is an optional sequence of column names to be passed on
+This method allows you to *quickly* insert large blocks of data in a table.
+Internally, it uses the COPY command of the PostgreSQL database.
+The method takes an iterable of row values, which must be tuples or lists
+of the same size, containing the values for each inserted row.
+These may contain string, integer, long or double (real) values.
+``columns`` is an optional tuple or list of column names to be passed on
to the COPY command.

..
warning:: diff --git a/pgconn.c b/pgconn.c index 383a48bd..5bdee3a7 100644 --- a/pgconn.c +++ b/pgconn.c @@ -682,9 +682,9 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) /* Insert table */ static char conn_inserttable__doc__[] = -"inserttable(table, data, [columns]) -- insert list into table\n\n" -"The fields in the list must be in the same order as in the table\n" -"or in the list of columns if one is specified.\n"; +"inserttable(table, data, [columns]) -- insert iterable into table\n\n" +"The fields in the iterable must be in the same order as in the table\n" +"or in the list or tuple of columns if one is specified.\n"; static PyObject * conn_inserttable(connObject *self, PyObject *args) @@ -693,11 +693,8 @@ conn_inserttable(connObject *self, PyObject *args) char *table, *buffer, *bufpt, *bufmax; int encoding; size_t bufsiz; - PyObject *list, *sublist, *item, *columns = NULL; - PyObject *(*getitem) (PyObject *, Py_ssize_t); - PyObject *(*getsubitem) (PyObject *, Py_ssize_t); - PyObject *(*getcolumn) (PyObject *, Py_ssize_t); - Py_ssize_t i, j, m, n = 0; + PyObject *rows, *iter_row, *item, *columns = NULL; + Py_ssize_t i, j, m, n; if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); @@ -705,7 +702,7 @@ conn_inserttable(connObject *self, PyObject *args) } /* gets arguments */ - if (!PyArg_ParseTuple(args, "sO|O", &table, &list, &columns)) { + if (!PyArg_ParseTuple(args, "sO|O", &table, &rows, &columns)) { PyErr_SetString( PyExc_TypeError, "Method inserttable() expects a string and a list as arguments"); @@ -713,49 +710,43 @@ conn_inserttable(connObject *self, PyObject *args) } /* checks list type */ - if (PyList_Check(list)) { - m = PyList_Size(list); - getitem = PyList_GetItem; - } - else if (PyTuple_Check(list)) { - m = PyTuple_Size(list); - getitem = PyTuple_GetItem; - } - else { + if (!(iter_row = PyObject_GetIter(rows))) + { PyErr_SetString( PyExc_TypeError, - "Method inserttable() expects a list or a tuple" + "Method inserttable() expects an iterable" " as second argument"); return NULL; } + m = PySequence_Check(rows) ? 
PySequence_Size(rows) : -1; + if (!m) { + /* no rows specified, nothing to do */ + Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; + } /* checks columns type */ if (columns) { - if (PyList_Check(columns)) { - n = PyList_Size(columns); - getcolumn = PyList_GetItem; - } - else if (PyTuple_Check(columns)) { - n = PyTuple_Size(columns); - getcolumn = PyTuple_GetItem; - } - else { + if (!(PyTuple_Check(columns) || PyList_Check(columns))) { PyErr_SetString( PyExc_TypeError, - "Method inserttable() expects a list or a tuple" - " as third argument"); + "Method inserttable() expects a tuple or a list" + " as second argument"); return NULL; } + + n = PySequence_Fast_GET_SIZE(columns); if (!n) { /* no columns specified, nothing to do */ - Py_INCREF(Py_None); - return Py_None; + Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; } + } else { + n = -1; /* number of columns not yet known */ } /* allocate buffer */ - if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) - return PyErr_NoMemory(); + if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) { + Py_DECREF(iter_row); return PyErr_NoMemory(); + } encoding = PQclientEncoding(self->cnx); @@ -770,7 +761,7 @@ conn_inserttable(connObject *self, PyObject *args) if (bufpt < bufmax) bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), " ("); for (j = 0; j < n; ++j) { - PyObject *obj = getcolumn(columns, j); + PyObject *obj = PySequence_Fast_GET_ITEM(columns, j); Py_ssize_t slen; char *col; @@ -779,13 +770,18 @@ conn_inserttable(connObject *self, PyObject *args) } else if (PyUnicode_Check(obj)) { obj = get_encoded_string(obj, encoding); - if (!obj) return NULL; /* pass the UnicodeEncodeError */ + if (!obj) { + Py_DECREF(iter_row); + return NULL; /* pass the UnicodeEncodeError */ + } PyBytes_AsStringAndSize(obj, &col, &slen); Py_DECREF(obj); } else { PyErr_SetString( PyExc_TypeError, "The third argument must contain only strings"); + Py_DECREF(iter_row); + return NULL; } col = PQescapeIdentifier(self->cnx, col, (size_t) slen); if (bufpt < bufmax) @@ -797,7 +793,8 @@ conn_inserttable(connObject *self, PyObject *args) if (bufpt < bufmax) snprintf(bufpt, (size_t) (bufmax - bufpt), " from stdin"); if (bufpt >= bufmax) { - PyMem_Free(buffer); return PyErr_NoMemory(); + PyMem_Free(buffer); Py_DECREF(iter_row); + return PyErr_NoMemory(); } Py_BEGIN_ALLOW_THREADS @@ -805,7 +802,7 @@ conn_inserttable(connObject *self, PyObject *args) Py_END_ALLOW_THREADS if (!result) { - PyMem_Free(buffer); + PyMem_Free(buffer); Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; } @@ -813,33 +810,29 @@ conn_inserttable(connObject *self, PyObject *args) PQclear(result); /* feed table */ - for (i = 0; i < m; ++i) { - sublist = getitem(list, i); - if (PyTuple_Check(sublist)) { - j = PyTuple_Size(sublist); - getsubitem = PyTuple_GetItem; - } - else if (PyList_Check(sublist)) { - j = PyList_Size(sublist); - getsubitem = PyList_GetItem; - } - else { + for (i = 0; m < 0 || i < m; ++i) { + + if (!(columns = PyIter_Next(iter_row))) break; + + if (!(PyTuple_Check(columns) || PyList_Check(columns))) { + PyMem_Free(buffer); + Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, - "The second argument must contain a tuple or a list"); + "The second argument must contain tuples or lists"); return NULL; } - if (i) { - if (j != n) { - PyMem_Free(buffer); - PyErr_SetString( - PyExc_TypeError, - "Arrays contained in second arg must have same size"); - return NULL; - } - } - else { - n = j; /* never used before 
this assignment */ + + j = PySequence_Fast_GET_SIZE(columns); + if (n < 0) { + n = j; + } else if (j != n) { + PyMem_Free(buffer); + Py_DECREF(columns); Py_DECREF(iter_row); + PyErr_SetString( + PyExc_TypeError, + "The second arg must contain sequences of the same size"); + return NULL; } /* builds insert line */ @@ -851,7 +844,7 @@ conn_inserttable(connObject *self, PyObject *args) *bufpt++ = '\t'; --bufsiz; } - item = getsubitem(sublist, j); + item = PySequence_Fast_GET_ITEM(columns, j); /* convert item to string and append to buffer */ if (item == Py_None) { @@ -877,6 +870,7 @@ conn_inserttable(connObject *self, PyObject *args) PyObject *s = get_encoded_string(item, encoding); if (!s) { PyMem_Free(buffer); + Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } else { @@ -916,22 +910,30 @@ conn_inserttable(connObject *self, PyObject *args) } if (bufsiz <= 0) { - PyMem_Free(buffer); return PyErr_NoMemory(); + PyMem_Free(buffer); Py_DECREF(columns); Py_DECREF(iter_row); + return PyErr_NoMemory(); } } + Py_DECREF(columns); + *bufpt++ = '\n'; *bufpt = '\0'; /* sends data */ if (PQputline(self->cnx, buffer)) { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); PQendcopy(self->cnx); - PyMem_Free(buffer); + PyMem_Free(buffer); Py_DECREF(iter_row); return NULL; } } + Py_DECREF(iter_row); + if (PyErr_Occurred()) { + PyMem_Free(buffer); return NULL; /* pass the iteration error */ + } + /* ends query */ if (PQputline(self->cnx, "\\.\n")) { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index b51a20f8..b62e6dc5 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1890,15 +1890,42 @@ def testInserttableFromTupleOfLists(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromSetofTuples(self): - data = {row for row in self.data} + def testInserttableWithDifferentRowSizes(self): + data = self.data[:-1] + [self.data[-1][:-1]] try: self.c.inserttable('test', data) except TypeError as e: r = str(e) else: r = 'this is fine' - self.assertIn('list or a tuple as second argument', r) + self.assertIn('second arg must contain sequences of the same size', r) + + def testInserttableFromSetofTuples(self): + data = {row for row in self.data} + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) + + def testInserttableFromDictAsInterable(self): + data = {row: None for row in self.data} + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) + + def testInserttableFromDictKeys(self): + data = {row: None for row in self.data} + keys = data.keys() + self.c.inserttable('test', keys) + self.assertEqual(self.get_back(), self.data) + + def testInserttableFromDictValues(self): + data = {i: row for i, row in enumerate(self.data)} + values = data.values() + self.c.inserttable('test', values) + self.assertEqual(self.get_back(), self.data) + + def testInserttableFromGeneratorOfTuples(self): + data = (row for row in self.data) + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) def testInserttableFromListOfSets(self): data = [set(row) for row in self.data] @@ -1908,7 +1935,7 @@ def testInserttableFromListOfSets(self): r = str(e) else: r = 'this is fine' - self.assertIn('second argument must contain a tuple or a list', r) + self.assertIn('second argument must contain tuples or lists', r) def 
testInserttableMultipleRows(self): num_rows = 100 @@ -2078,7 +2105,7 @@ def testInserttableNoEncoding(self): def testInserttableTooLargeColumnSpecification(self): # should catch buffer overflow when building the column specification self.assertRaises(MemoryError, self.c.inserttable, - 'test', [], ['very_long_column_name'] * 1000) + 'test', self.data, ['very_long_column_name'] * 1000) class TestDirectSocketAccess(unittest.TestCase): From b29663a05d3d799d74a0f360ee4e7498ae2a7ed0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 18:28:40 +0100 Subject: [PATCH 074/194] Fix error message in inserttable() --- pgconn.c | 2 +- tests/test_classic_connection.py | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pgconn.c b/pgconn.c index 5bdee3a7..c4da45fb 100644 --- a/pgconn.c +++ b/pgconn.c @@ -730,7 +730,7 @@ conn_inserttable(connObject *self, PyObject *args) PyErr_SetString( PyExc_TypeError, "Method inserttable() expects a tuple or a list" - " as second argument"); + " as third argument"); return NULL; } diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index b62e6dc5..8ae9ce4f 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1980,7 +1980,7 @@ def testInserttableWithInvalidTableName(self): data = [(42,)] # check that the table name is not inserted unescaped # (this would pass otherwise since there is a column named i4) - self.assertRaises(Exception, self.c.inserttable, 'test (i4)', data) + self.assertRaises(OSError, self.c.inserttable, 'test (i4)', data) # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i4']) @@ -1989,10 +1989,25 @@ def testInserttableWithInvalidColumnName(self): # check that the column names are not inserted unescaped # (this would pass otherwise since there are columns i2 and i4) self.assertRaises( - Exception, self.c.inserttable, 'test', data, ['i2,i4']) + TypeError, self.c.inserttable, 'test', data, ['i2,i4']) # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i2', 'i4']) + def testInserttableWithInvalidColumList(self): + data = self.data + try: + self.c.inserttable('test', data, 'invalid') + except TypeError as e: + r = str(e) + else: + r = 'this is fine' + self.assertIn('expects a tuple or a list as third argument', r) + + def testInserttableWithHugeListOfColumnNames(self): + # should catch buffer overflow when building the column specification + self.assertRaises(MemoryError, self.c.inserttable, + 'test', self.data, ['very_long_column_name'] * 1000) + def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), True, '2999-12-31', '11:59:59', 1e99, @@ -2102,11 +2117,6 @@ def testInserttableNoEncoding(self): # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) - def testInserttableTooLargeColumnSpecification(self): - # should catch buffer overflow when building the column specification - self.assertRaises(MemoryError, self.c.inserttable, - 'test', self.data, ['very_long_column_name'] * 1000) - class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From bbf83caad35fe5ad7f9fb60640e3a131f89990a4 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 18:29:50 +0100 Subject: [PATCH 075/194] Ignore docker files used for development --- .gitignore | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/.gitignore b/.gitignore index f826aa80..bd57f86b 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ _build_doctrees/ /local/ /tests/LOCAL_*.py +docker-compose.yml +Dockerfile Vagrantfile .coverage From a4eca6247b8af5a17905e1a79c2a4599ba87701b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 18:58:43 +0100 Subject: [PATCH 076/194] Add test for inserttable() with data from query --- tests/test_classic_connection.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 8ae9ce4f..59de5c89 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1776,7 +1776,7 @@ def setUpClass(cls): c.query("drop table if exists test cascade") c.query("create table test (" "i2 smallint, i4 integer, i8 bigint," - " b boolean, dt date, ti time," + "b boolean, dt date, ti time," "d numeric, f4 real, f8 double precision, m money," "c char(1), v4 varchar(4), c4 char(4), t text)") # Check whether the test database uses SQL_ASCII - this means @@ -2117,6 +2117,17 @@ def testInserttableNoEncoding(self): # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) + def testInserttableFromQuery(self): + data = self.c.query( + "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," + "null as dt, null as ti, null as d," + "4.5::float as float4, 8.5::float8 as f8," + "null as m, 'c' as c, 'v4' as v4, null as c4, 'text' as text") + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), [ + (2, 4, 8, True, None, None, None, 4.5, 8.5, + None, 'c', 'v4', None, 'text')]) + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From 3c965127a1a4750393e87f7701fc44de2d14fe14 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 19:07:48 +0100 Subject: [PATCH 077/194] Call PQendcopy in inserttable also in case of an error (#60) --- pgconn.c | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pgconn.c b/pgconn.c index c4da45fb..a33c881d 100644 --- a/pgconn.c +++ b/pgconn.c @@ -815,7 +815,7 @@ conn_inserttable(connObject *self, PyObject *args) if (!(columns = PyIter_Next(iter_row))) break; if (!(PyTuple_Check(columns) || PyList_Check(columns))) { - PyMem_Free(buffer); + PQendcopy(self->cnx); PyMem_Free(buffer); Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, @@ -827,7 +827,7 @@ conn_inserttable(connObject *self, PyObject *args) if (n < 0) { n = j; } else if (j != n) { - PyMem_Free(buffer); + PQendcopy(self->cnx); PyMem_Free(buffer); Py_DECREF(columns); Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, @@ -869,7 +869,7 @@ conn_inserttable(connObject *self, PyObject *args) else if (PyUnicode_Check(item)) { PyObject *s = get_encoded_string(item, encoding); if (!s) { - PyMem_Free(buffer); + PQendcopy(self->cnx); PyMem_Free(buffer); Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } @@ -910,7 +910,8 @@ conn_inserttable(connObject *self, PyObject *args) } if (bufsiz <= 0) { - PyMem_Free(buffer); Py_DECREF(columns); Py_DECREF(iter_row); + PQendcopy(self->cnx); PyMem_Free(buffer); + Py_DECREF(columns); Py_DECREF(iter_row); return PyErr_NoMemory(); } @@ -923,22 +924,21 @@ conn_inserttable(connObject *self, PyObject *args) /* sends data */ if 
(PQputline(self->cnx, buffer)) { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - PQendcopy(self->cnx); - PyMem_Free(buffer); Py_DECREF(iter_row); + PQendcopy(self->cnx); PyMem_Free(buffer); Py_DECREF(iter_row); return NULL; } } Py_DECREF(iter_row); if (PyErr_Occurred()) { - PyMem_Free(buffer); return NULL; /* pass the iteration error */ + PQendcopy(self->cnx); PyMem_Free(buffer); + return NULL; /* pass the iteration error */ } /* ends query */ if (PQputline(self->cnx, "\\.\n")) { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - PQendcopy(self->cnx); - PyMem_Free(buffer); + PQendcopy(self->cnx); PyMem_Free(buffer); return NULL; } From 5a7684b57982cb58e6470c5516b186f9b277830b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 19:24:44 +0100 Subject: [PATCH 078/194] inserttable() had insufficient check of result (#62) --- pgconn.c | 2 +- tests/test_classic_connection.py | 34 +++++++++++++++++++++----------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pgconn.c b/pgconn.c index a33c881d..c344566f 100644 --- a/pgconn.c +++ b/pgconn.c @@ -801,7 +801,7 @@ conn_inserttable(connObject *self, PyObject *args) result = PQexec(self->cnx, buffer); Py_END_ALLOW_THREADS - if (!result) { + if (!result || PQresultStatus(result) != PGRES_COPY_IN) { PyMem_Free(buffer); Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 59de5c89..48381815 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1895,10 +1895,10 @@ def testInserttableWithDifferentRowSizes(self): try: self.c.inserttable('test', data) except TypeError as e: - r = str(e) + self.assertIn( + 'second arg must contain sequences of the same size', str(e)) else: - r = 'this is fine' - self.assertIn('second arg must contain sequences of the same size', r) + self.assertFalse('expected an error') def testInserttableFromSetofTuples(self): data = {row for row in self.data} @@ -1932,10 +1932,10 @@ def testInserttableFromListOfSets(self): try: self.c.inserttable('test', data) except TypeError as e: - r = str(e) + self.assertIn( + 'second argument must contain tuples or lists', str(e)) else: - r = 'this is fine' - self.assertIn('second argument must contain tuples or lists', r) + self.assertFalse('expected an error') def testInserttableMultipleRows(self): num_rows = 100 @@ -1980,7 +1980,12 @@ def testInserttableWithInvalidTableName(self): data = [(42,)] # check that the table name is not inserted unescaped # (this would pass otherwise since there is a column named i4) - self.assertRaises(OSError, self.c.inserttable, 'test (i4)', data) + try: + self.c.inserttable('test (i4)', data) + except ValueError as e: + self.assertIn('relation "test (i4)" does not exist', str(e)) + else: + self.assertFalse('expected an error') # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i4']) @@ -1988,8 +1993,13 @@ def testInserttableWithInvalidColumnName(self): data = [(2, 4)] # check that the column names are not inserted unescaped # (this would pass otherwise since there are columns i2 and i4) - self.assertRaises( - TypeError, self.c.inserttable, 'test', data, ['i2,i4']) + try: + self.c.inserttable('test', data, ['i2,i4']) + except ValueError as e: + self.assertIn( + 'column "i2,i4" of relation "test" does not exist', str(e)) + else: + self.assertFalse('expected an error') # make sure that it works if 
parameters are passed properly self.c.inserttable('test', data, ['i2', 'i4']) @@ -1998,10 +2008,10 @@ def testInserttableWithInvalidColumList(self): try: self.c.inserttable('test', data, 'invalid') except TypeError as e: - r = str(e) + self.assertIn( + 'expects a tuple or a list as third argument', str(e)) else: - r = 'this is fine' - self.assertIn('expects a tuple or a list as third argument', r) + self.assertFalse('expected an error') def testInserttableWithHugeListOfColumnNames(self): # should catch buffer overflow when building the column specification From a3f93dac7e8e88456ec88fb7764e8c2c13197856 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 29 Jan 2022 20:46:12 +0100 Subject: [PATCH 079/194] Bump version and update changelog --- .bumpversion.cfg | 2 +- LICENSE.txt | 2 +- docs/about.txt | 6 +++--- docs/announce.rst | 8 ++++---- docs/conf.py | 4 ++-- docs/contents/changelog.rst | 9 +++++++++ docs/contents/install.rst | 2 +- docs/contents/pg/adaptation.rst | 2 +- docs/copyright.rst | 2 +- pg.py | 2 +- pgconn.c | 2 +- pgdb.py | 2 +- pginternal.c | 2 +- pglarge.c | 2 +- pgmodule.c | 2 +- pgnotice.c | 2 +- pgquery.c | 2 +- pgsource.c | 2 +- setup.py | 11 ++++++----- tests/test_classic_connection.py | 2 +- tests/test_classic_dbwrapper.py | 2 +- tests/test_classic_functions.py | 2 +- tox.ini | 4 ++-- 23 files changed, 43 insertions(+), 33 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 24ad1614..8715d670 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.2 +current_version = 5.2.3 commit = False tag = False diff --git a/LICENSE.txt b/LICENSE.txt index 4ff09c11..c10a5870 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2020 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2022 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.txt b/docs/about.txt index bff0af4f..9379b9a6 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -5,7 +5,7 @@ PostgreSQL features from a Python script. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2020 by the PyGreSQL team. + | Further modifications are copyright © 2009-2022 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.2 needs PostgreSQL 9.0 to 9.6 or 10 to 13, and -Python 2.7 or 3.5 to 3.9. If you need to support older PostgreSQL versions or +The current version PyGreSQL 5.2.3 needs PostgreSQL 9.0 to 9.6 or 10 to 14, and +Python 2.7 or 3.5 to 3.10. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. 
diff --git a/docs/announce.rst b/docs/announce.rst index 6fc4ac49..9db1c62d 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -3,10 +3,10 @@ PyGreSQL Announcements ====================== --------------------------------- -Release of PyGreSQL version 5.2.2 +Release of PyGreSQL version 5.2.3 --------------------------------- -Release 5.2.2 of PyGreSQL. +Release 5.2.3 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. @@ -22,8 +22,8 @@ This version has been built and unit tested on: - openSUSE - Ubuntu - Windows 7 and 10 with both MinGW and Visual Studio - - PostgreSQL 9.0 to 9.6 and 10 to 13 (32 and 64bit) - - Python 2.7 and 3.5 to 3.9 (32 and 64bit) + - PostgreSQL 9.0 to 9.6 and 10 to 14 (32 and 64bit) + - Python 2.7 and 3.5 to 3.10 (32 and 64bit) | D'Arcy J.M. Cain | darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index b8eafe7d..721c6bf8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,14 +61,14 @@ # General information about the project. project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2020, ' + author +copyright = '2022, ' + author # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2.2' +version = release = '5.2.3' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 6c2ba061..19fdaa83 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,6 +3,15 @@ ChangeLog Version 5.2.3 (to be released) ------------------------------ +- This version officially supports the new Python 3.10 and PostgreSQL 14. +- Some improvements and fixes in the `inserttable()` method of the `pg` module: + - Sync with `PQendcopy()` when there was an error (#60) + - Improved check for internal result (#62) + - Catch buffer overflows when building the copy command + - Data can now be passed as an iterable, not just list or tuple (#66) +- Some more fixes in the `pg` module: + - Fix argument handling of `is/set_non_blocking()`. + - Add missing `get/set_typecasts` in list of exports. - Fixed a reference counting issue when casting JSON columns (#57). Version 5.2.2 (2020-12-09) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index bc9897c7..1b7ef55e 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.7 and 3.5 to 3.9, and PostgreSQL versions 9.0 to 9.6 and 10 to 13. +2.7 and 3.5 to 3.10, and PostgreSQL versions 9.0 to 9.6 and 10 to 14. PyGreSQL will be installed as three modules, a shared library called ``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index b1ada9bd..1cf44418 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -363,7 +363,7 @@ With PostgreSQL we can easily calculate that these two circles overlap:: True However, calculating the intersection points between the two circles using the -``#`` operator does not work (at least not as of PostgreSQL version 13). 
+``#`` operator does not work (at least not as of PostgreSQL version 14). So let's resort to SymPy to find out. To ease importing circles from PostgreSQL to SymPy, we create and register the following typecast function:: diff --git a/docs/copyright.rst b/docs/copyright.rst index 77d9ef83..4c9aacc6 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2020 by the PyGreSQL team. +Further modifications copyright (c) 2009-2022 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/pg.py b/pg.py index 6c1bd35c..81b64a72 100644 --- a/pg.py +++ b/pg.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2020 by the PyGreSQL Development Team +# Copyright (c) 2022 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgconn.c b/pgconn.c index c344566f..534c86ce 100644 --- a/pgconn.c +++ b/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgdb.py b/pgdb.py index 3eee8de1..6f5116d0 100644 --- a/pgdb.py +++ b/pgdb.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2020 by the PyGreSQL Development Team +# Copyright (c) 2022 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. diff --git a/pginternal.c b/pginternal.c index 25e1dcc8..91c565be 100644 --- a/pginternal.c +++ b/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pglarge.c b/pglarge.c index f2c3a63e..ed8f1824 100644 --- a/pglarge.c +++ b/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgmodule.c b/pgmodule.c index ee4e101f..3096fdd9 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgnotice.c b/pgnotice.c index 7f0c0cc4..7e5b93c7 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgquery.c b/pgquery.c index 63b86c32..a04eb68b 100644 --- a/pgquery.c +++ b/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. 
* - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgsource.c b/pgsource.c index 2311e2a0..4fa04365 100644 --- a/pgsource.c +++ b/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2022 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/setup.py b/setup.py index ed02c404..bc1de77c 100755 --- a/setup.py +++ b/setup.py @@ -2,11 +2,11 @@ # # PyGreSQL - a Python interface for the PostgreSQL database. # -# Copyright (c) 2020 by the PyGreSQL Development Team +# Copyright (c) 2022 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.2 +"""Setup script for PyGreSQL version 5.2.3 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It embeds the PostgreSQL query library to allow @@ -26,8 +26,8 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.9, -and PostgreSQL versions 9.0 to 9.6 and 10 to 13. +PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.10, +and PostgreSQL versions 9.0 to 9.6 and 10 to 14. Use as follows: python setup.py build_ext # to build the module @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.2' +version = '5.2.3' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): @@ -251,6 +251,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 48381815..c99cfa21 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -208,7 +208,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 140000) + self.assertTrue(90000 <= server_version < 150000) def testAttributeSocket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 43dffaf5..90e4d619 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -265,7 +265,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 140000) + self.assertTrue(90000 <= server_version < 150000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index e59828cc..d7c7a720 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -126,7 +126,7 @@ def testPqlibVersion(self): v = pg.get_pqlib_version() self.assertIsInstance(v, long) self.assertGreater(v, 
90000)
-        self.assertLess(v, 140000)
+        self.assertLess(v, 150000)


 class TestParseArray(unittest.TestCase):
diff --git a/tox.ini b/tox.ini
index fa945503..706a54a1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,14 +5,14 @@ envlist = py{27,35,36,37,38,39,310},flake8,docs

 [testenv:flake8]
 basepython = python3.9
-deps = flake8>=3.8,<4
+deps = flake8>=4,<5
 commands =
     flake8 setup.py pg.py pgdb.py tests

 [testenv:docs]
 basepython = python3.9
 deps =
-    sphinx>=3.2,<4
+    sphinx>=3.5,<4
     cloud_sptheme>=1.10,<2
 commands =
     sphinx-build -b html -nEW docs docs/_build/html

From c3b54d5f8e87612df89b934459359d401ab68188 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 29 Jan 2022 20:48:36 +0100
Subject: [PATCH 080/194] Ignore shared libraries produced during builds

---
 .gitignore | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.gitignore b/.gitignore
index bd57f86b..67e3ae35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,14 @@
 *~
 *.bak
 *.cache
+*.dll
 *.egg-info
 *.log
 *.patch
 *.pid
 *.pstats
 *.py[co]
+*.so
 *.swp

 __pycache__/

From ac66e2e9729c8efc4d594b5a122029ea04b297e2 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 29 Jan 2022 21:31:02 +0100
Subject: [PATCH 081/194] Add test for copying a table using inserttable

---
 tests/test_classic_dbwrapper.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py
index 90e4d619..f2618a23 100755
--- a/tests/test_classic_dbwrapper.py
+++ b/tests/test_classic_dbwrapper.py
@@ -4227,6 +4227,20 @@ def testNotificationHandler(self):
         self.db.reopen()
         self.assertIsNone(handler.db)

+    def testInserttableFromQuery(self):
+        # use inserttable() to copy from one table to another
+        query = self.db.query
+        self.createTable('test_table_from', 'n integer, t timestamp')
+        self.createTable('test_table_to', 'n integer, t timestamp')
+        for i in range(1, 4):
+            query("insert into test_table_from values ($1, now())", i)
+        self.db.inserttable(
+            'test_table_to', query("select n, t::text from test_table_from"))
+        data_from = query("select * from test_table_from").getresult()
+        data_to = query("select * from test_table_to").getresult()
+        self.assertEqual([row[0] for row in data_from], [1, 2, 3])
+        self.assertEqual(data_from, data_to)
+

 class TestDBClassNonStdOpts(TestDBClass):
     """Test the methods of the DB class with non-standard global options."""

From a40443c4ac7c56b76bf9b662a24e1a32c4cec8e6 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 29 Jan 2022 22:36:43 +0100
Subject: [PATCH 082/194] Add GitHub workflows for testing

Also improve the configurability of the test database.
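For local runs outside the CI matrix, the same variables can be set
before the test modules import tests/config.py (a sketch; the values
are examples only, matching the defaults used in the workflow):

    from os import environ

    environ.setdefault('PYGRESQL_DB', 'test')
    environ.setdefault('PYGRESQL_HOST', '127.0.0.1')
    environ.setdefault('PYGRESQL_PORT', '5432')
    environ.setdefault('PYGRESQL_USER', 'test')
    environ.setdefault('PYGRESQL_PASSWD', 'test')

Alternatively, export them in the shell before running tox or
"python -m unittest discover".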
--- .../workflows/{release-docs.yml => docs.yml} | 0 .github/workflows/lint.yml | 24 +++++++ .github/workflows/tests.yml | 67 +++++++++++++++++++ tests/config.py | 28 ++++++++ tests/test_classic.py | 16 +---- tests/test_classic_connection.py | 25 ++----- tests/test_classic_dbwrapper.py | 17 +---- tests/test_classic_largeobj.py | 17 +---- tests/test_classic_notification.py | 9 +-- tests/test_dbapi20.py | 17 +---- tests/test_dbapi20_copy.py | 20 ++---- tests/test_tutorial.py | 20 ++---- tox.ini | 4 +- 13 files changed, 148 insertions(+), 116 deletions(-) rename .github/workflows/{release-docs.yml => docs.yml} (100%) create mode 100644 .github/workflows/lint.yml create mode 100644 .github/workflows/tests.yml create mode 100644 tests/config.py diff --git a/.github/workflows/release-docs.yml b/.github/workflows/docs.yml similarity index 100% rename from .github/workflows/release-docs.yml rename to .github/workflows/docs.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..437449a1 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,24 @@ +name: Run PyGreSQL quality checks + +on: + push: + pull_request: + +jobs: + checks: + name: Quality checks run + runs-on: ubuntu-20.04 + + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Install tox + run: pip install tox + - uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Run quality checks + run: tox -e flake8,docs + timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..b6eeecd1 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,67 @@ +name: Run PyGreSQL test matrix + +# this has been shamelessly copied from Psycopg + +on: + push: + pull_request: + +jobs: + tests: + name: Unit tests run + runs-on: ubuntu-18.04 + + strategy: + fail-fast: false + matrix: + include: + - {python: "2.7", postgres: "9.3"} + - {python: "3.5", postgres: "9.6"} + - {python: "3.6", postgres: "10"} + - {python: "3.7", postgres: "11"} + - {python: "3.8", postgres: "12"} + - {python: "3.9", postgres: "13"} + - {python: "3.10", postgres: "14"} + + # Opposite extremes of the supported Py/PG range, other architecture + - {python: "2.7", postgres: "14", architecture: "x86"} + - {python: "3.5", postgres: "13", architecture: "x86"} + - {python: "3.6", postgres: "12", architecture: "x86"} + - {python: "3.7", postgres: "11", architecture: "x86"} + - {python: "3.8", postgres: "10", architecture: "x86"} + - {python: "3.9", postgres: "9.6", architecture: "x86"} + - {python: "3.10", postgres: "9.3", architecture: "x86"} + + env: + PYGRESQL_DB: test + PYGRESQL_HOST: 127.0.0.1 + PYGRESQL_USER: test + PYGRESQL_PASSWD: test + + services: + postgresql: + image: postgres:${{ matrix.postgres }} + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v2 + - name: Install tox + run: pip install tox + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + - name: Run tests + env: + MATRIX_PYTHON: ${{ matrix.python }} + run: tox -e py${MATRIX_PYTHON/./} + timeout-minutes: 5 diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 00000000..a6082593 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,28 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +from os 
import environ + +# We need a database to test against. +# If LOCAL_PyGreSQL.py exists, we will get our information from that. +# Otherwise, we use the defaults. + +# The tests should be run with various PostgreSQL versions and databases +# created with different encodings and locales. Particularly, make sure the +# tests are running against databases created with both SQL_ASCII and UTF8. + +# The current user must have create schema privilege on the database. + +dbname = environ.get('PYGRESQL_DB', 'unittest') +dbhost = environ.get('PYGRESQL_HOST', None) +dbport = environ.get('PYGRESQL_PORT', 5432) +dbuser = environ.get('PYGRESQL_USER', None) +dbpasswd = environ.get('PYGRESQL_PASSWD', None) + +try: + from .LOCAL_PyGreSQL import * # noqa: F401 +except (ImportError, ValueError): + try: + from LOCAL_PyGreSQL import * # noqa: F401 + except ImportError: + pass diff --git a/tests/test_classic.py b/tests/test_classic.py index c3b731d1..727e4a86 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -11,23 +11,11 @@ from pg import * -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd def open_db(): - db = DB(dbname, dbhost, dbport) + db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("SET DATESTYLE TO 'ISO'") db.query("SET TIME ZONE 'EST5EDT'") db.query("SET DEFAULT_WITH_OIDS=FALSE") diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c99cfa21..6f5914a3 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -27,22 +27,7 @@ import pg # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# These tests should be run with various PostgreSQL versions and databases -# created with different encodings and locales. Particularly, make sure the -# tests are running against databases created with both SQL_ASCII and UTF8. 
-dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long @@ -67,7 +52,8 @@ def connect(): """Create a basic pg connection to the test database.""" # noinspection PyArgumentList - connection = pg.connect(dbname, dbhost, dbport) + connection = pg.connect(dbname, dbhost, dbport, + user=dbuser, passwd=dbpasswd) connection.query("set client_min_messages=warning") return connection @@ -75,7 +61,8 @@ def connect(): def connect_nowait(): """Start a basic pg connection in a non-blocking manner.""" # noinspection PyArgumentList - return pg.connect(dbname, dbhost, dbport, nowait=True) + return pg.connect(dbname, dbhost, dbport, + user=dbuser, passwd=dbpasswd, nowait=True) class TestCanConnect(unittest.TestCase): @@ -102,7 +89,7 @@ def testCanConnectNoWait(self): try: connection = connect_nowait() rc = connection.poll() - self.assertEqual(rc, pg.POLLING_READING) + self.assertIn(rc, (pg.POLLING_READING, pg.POLLING_WRITING)) while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): rc = connection.poll() except pg.Error as error: diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index f2618a23..f0c02a56 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -26,23 +26,10 @@ from time import strftime from operator import itemgetter -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# The current user must have create schema privilege on the database. -dbname = 'unittest' -dbhost = None -dbport = 5432 +from .config import dbname, dbhost, dbport, dbuser, dbpasswd debug = False # let DB wrapper print debugging output -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass - try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long except NameError: # Python >= 3.0 @@ -68,7 +55,7 @@ def DB(): """Create a DB wrapper object connecting to the test database.""" - db = pg.DB(dbname, dbhost, dbport) + db = pg.DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) if debug: db.debug = debug db.query("set client_min_messages=warning") diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index b82d56fa..fc0464d5 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -16,19 +16,7 @@ import pg # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. 
-dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd windows = os.name == 'nt' @@ -36,7 +24,8 @@ # noinspection PyArgumentList def connect(): """Create a basic pg connection to the test database.""" - connection = pg.connect(dbname, dbhost, dbport) + connection = pg.connect(dbname, dbhost, dbport, + user=dbuser, passwd=dbpasswd) connection.query("set client_min_messages=warning") return connection diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 8a8d86d9..da2a8fa4 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -17,12 +17,7 @@ import pg # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# The current user must have create schema privilege on the database. -dbname = 'unittest' -dbhost = None -dbport = 5432 +from .config import dbname, dbhost, dbport, dbuser, dbpasswd debug = False # let DB wrapper print debugging output @@ -37,7 +32,7 @@ def DB(): """Create a DB wrapper object connecting to the test database.""" - db = pg.DB(dbname, dbhost, dbport) + db = pg.DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) if debug: db.debug = debug return db diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 3dad0d7a..94b7ab73 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -16,19 +16,7 @@ # noinspection PyUnresolvedReferences import dbapi20 -# We need a database to test against. -# If LOCAL_PyGreSQL.py exists we will get our information from that. -# Otherwise we use the defaults. -dbname = 'dbapi20_test' -dbhost = '' -dbport = 5432 -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences long @@ -51,7 +39,8 @@ class test_PyGreSQL(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () connect_kw_args = { - 'database': dbname, 'host': '%s:%d' % (dbhost or '', dbport or -1)} + 'database': dbname, 'host': '%s:%d' % (dbhost or '', dbport or -1), + 'user': dbuser, 'password': dbpasswd} lower_func = 'lower' # For stored procedure test diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index dbe25fd1..47fc012a 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -19,20 +19,7 @@ import pgdb # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# The current user must have create schema privilege on the database. 
-dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences unicode @@ -126,8 +113,9 @@ class TestCopy(unittest.TestCase): @staticmethod def connect(): - return pgdb.connect( - database=dbname, host='%s:%d' % (dbhost or '', dbport or -1)) + host = '%s:%d' % (dbhost or '', dbport or -1) + return pgdb.connect(database=dbname, host=host, + user=dbuser, password=dbpasswd) @classmethod def setUpClass(cls): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 94871ecd..dbc93024 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -8,19 +8,7 @@ from pg import DB from pgdb import connect -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass +from .config import dbname, dbhost, dbport, dbuser, dbpasswd class TestClassicTutorial(unittest.TestCase): @@ -28,7 +16,7 @@ class TestClassicTutorial(unittest.TestCase): def setUp(self): """Setup test tables or empty them if they already exist.""" - db = DB(dbname=dbname, host=dbhost, port=dbport) + db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("set datestyle to 'iso'") db.query("set default_with_oids=false") db.query("set standard_conforming_strings=false") @@ -122,9 +110,9 @@ class TestDbApi20Tutorial(unittest.TestCase): def setUp(self): """Setup test tables or empty them if they already exist.""" - database = dbname host = '%s:%d' % (dbhost or '', dbport or -1) - con = connect(database=database, host=host) + con = connect(database=dbname, host=host, + user=dbuser, password=dbpasswd) cur = con.cursor() cur.execute("set datestyle to 'iso'") cur.execute("set default_with_oids=false") diff --git a/tox.ini b/tox.ini index 706a54a1..f9307752 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py{27,35,36,37,38,39,310},flake8,docs +envlist = 27,3{5,6,7,8,9,10},flake8,docs [testenv:flake8] basepython = python3.9 @@ -18,6 +18,8 @@ commands = sphinx-build -b html -nEW docs docs/_build/html [testenv] +passenv = PG* PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size python -m unittest discover {posargs} + From 663dd9040d5c219fe77ab436b9583f0c2488d978 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 00:45:47 +0100 Subject: [PATCH 083/194] Travis CI configuration not needed any more We now use GitHub actions instead. 
--- .travis.yml | 57 ----------------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index dde3e96b..00000000 --- a/.travis.yml +++ /dev/null @@ -1,57 +0,0 @@ -# Travis CI configuration -# see https://docs.travis-ci.com/user/languages/python - -language: python - -matrix: - include: - - name: Code quality tests - env: TOXENV=flake8,docs - python: 3.8 - # - name: Unit tests with Python 3.9 - # env: TOXENV=py39 - # python: 3.9 - - name: Unit tests with Python 3.8 - env: TOXENV=py38 - python: 3.8 - - name: Unit tests with Python 3.7 - env: TOXENV=py37 - python: 3.7 - - name: Unit tests with Python 3.6 - env: TOXENV=py36 - python: 3.6 - - name: Unit tests with Python 3.5 - env: TOXENV=py35 - python: 3.5 - - name: Unit tests with Python 2.7 - env: TOXENV=py27 - python: 2.7 - -cache: - directories: - - "$HOME/.cache/pip" - - "$TRAVIS_BUILD_DIR/.tox" - -install: - - pip install tox-travis - -script: - - tox -e $TOXENV - -addons: - # last PostgreSQL version that still supports OIDs (11) - postgresql: "11" - apt: - packages: - - postgresql-11 - - postgresql-server-dev-11 - -services: - - postgresql - -before_script: - - sudo service postgresql stop - - sudo -u postgres sed -i "s/port = 54[0-9][0-9]/port = 5432/" /etc/postgresql/11/main/postgresql.conf - - sudo service postgresql start 11 - - sudo -u postgres psql -c 'create database unittest' - From 307ec95c40a56d33bf478de72575402f1774981c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 13:23:35 +0100 Subject: [PATCH 084/194] Split dotted table names when escaping in inserttable (#61) This is the same pragmatic solution as used in the copy methods of pgdb. We should implement a proper solution by allowing tuples or a separate schema parameter in the next version. --- docs/contents/changelog.rst | 1 + pgconn.c | 17 +++++++++++++---- tests/test_classic_connection.py | 13 +++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 19fdaa83..cd5c395f 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -6,6 +6,7 @@ Version 5.2.3 (to be released) - This version officially supports the new Python 3.10 and PostgreSQL 14. 
- Some improvements and fixes in the `inserttable()` method of the `pg` module: - Sync with `PQendcopy()` when there was an error (#60) + - Allow specifying a schema in the table name (#61) - Improved check for internal result (#62) - Catch buffer overflows when building the copy command - Data can now be passed as an iterable, not just list or tuple (#66) diff --git a/pgconn.c b/pgconn.c index 534c86ce..d1b0d2f8 100644 --- a/pgconn.c +++ b/pgconn.c @@ -690,7 +690,7 @@ static PyObject * conn_inserttable(connObject *self, PyObject *args) { PGresult *result; - char *table, *buffer, *bufpt, *bufmax; + char *table, *buffer, *bufpt, *bufmax, *s, *t; int encoding; size_t bufsiz; PyObject *rows, *iter_row, *item, *columns = NULL; @@ -753,9 +753,18 @@ conn_inserttable(connObject *self, PyObject *args) /* starts query */ bufpt = buffer; bufmax = bufpt + MAX_BUFFER_SIZE; - table = PQescapeIdentifier(self->cnx, table, strlen(table)); - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy %s", table); - PQfreemem(table); + bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy "); + + s = table; + do { + t = strchr(s, '.'); if (!t) t = s + strlen(s); + table = PQescapeIdentifier(self->cnx, s, (size_t) (t - s)); + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s", table); + PQfreemem(table); + s = t; if (*s && bufpt < bufmax) *bufpt++ = *s++; + } while (*s); + if (columns) { /* adds a string like f" ({','.join(columns)})" */ if (bufpt < bufmax) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 6f5914a3..9d30b0ad 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1963,6 +1963,11 @@ def testInserttableOnlyTwoColumns(self): + (None,) * 6 for i in range(20)] self.assertEqual(self.get_back(), data) + def testInserttableWithDottedTableName(self): + data = self.data + self.c.inserttable('public.test', data) + self.assertEqual(self.get_back(), data) + def testInserttableWithInvalidTableName(self): data = [(42,)] # check that the table name is not inserted unescaped @@ -1976,6 +1981,14 @@ def testInserttableWithInvalidTableName(self): # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i4']) + def testInserttableWithInvalidDataType(self): + try: + self.c.inserttable('test', 42) + except TypeError as e: + self.assertIn('expects an iterable as second argument', str(e)) + else: + self.assertFalse('expected an error') + def testInserttableWithInvalidColumnName(self): data = [(2, 4)] # check that the column names are not inserted unescaped From 312c08af76a4e64e9a53e5d44c0e26b95cf17618 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 13:57:53 +0100 Subject: [PATCH 085/194] Fix upsert with limited number of columns (#58) --- pg.py | 2 +- tests/test_classic_dbwrapper.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pg.py b/pg.py index 81b64a72..bfb62ab4 100644 --- a/pg.py +++ b/pg.py @@ -2476,7 +2476,7 @@ def upsert(self, table, row=None, **kw): keyname.add('oid') for n in attnames: if n not in keyname: - value = kw.get(n, True) + value = kw.get(n, n in row) if value: if not isinstance(value, basestring): value = 'excluded.%s' % col(n) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index f0c02a56..25ade713 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -2160,6 +2160,12 @@ def testUpsert(self): s = dict(m=3, u='z') r = upsert(table, s, 
oid='invalid') self.assertIs(r, s) + s = dict(n=2) + # do not modify columns missing in the dict + r = upsert(table, s) + self.assertIs(r, s) + r = query(q).getresult() + self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) def testUpsertWithOids(self): if not self.oids: From 32f910dac938ddc1d1d2c9d644b0363aae6aa399 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 14:01:23 +0100 Subject: [PATCH 086/194] Update changelog --- docs/contents/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index cd5c395f..b813c271 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -11,6 +11,7 @@ Version 5.2.3 (to be released) - Catch buffer overflows when building the copy command - Data can now be passed as an iterable, not just list or tuple (#66) - Some more fixes in the `pg` module: + - Fix upsert with limited number of columns (#58). - Fix argument handling of `is/set_non_blocking()`. - Add missing `get/set_typecasts` in list of exports. - Fixed a reference counting issue when casting JSON columns (#57). From fe2efe9b92540013410b03745714f03530c0d125 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 14:51:01 +0100 Subject: [PATCH 087/194] Use assertion methods in tests consistently --- tests/test_classic_notification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index da2a8fa4..29e6921d 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -53,10 +53,10 @@ def testPgNotify(self): warnings.simplefilter("always") # noinspection PyDeprecation handler1 = pg.pgnotify(db, *args, **kwargs) - assert len(warn_msgs) == 1 + self.assertEqual(len(warn_msgs), 1) warn_msg = warn_msgs[0] - assert issubclass(warn_msg.category, DeprecationWarning) - assert 'deprecated' in str(warn_msg.message) + self.assertTrue(issubclass(warn_msg.category, DeprecationWarning)) + self.assertIn('deprecated', str(warn_msg.message)) self.assertIsInstance(handler1, pg.NotificationHandler) handler2 = db.notification_handler(*args, **kwargs) self.assertIsInstance(handler2, pg.NotificationHandler) From febc87907eab0b453a635dfad131a37c8446e30d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 30 Jan 2022 15:14:07 +0100 Subject: [PATCH 088/194] Make type resolution independent of search path (#65) --- pg.py | 23 ++++++++++++++--------- pgdb.py | 2 +- tests/test_classic_dbwrapper.py | 7 +++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/pg.py b/pg.py index bfb62ab4..b03c3d71 100644 --- a/pg.py +++ b/pg.py @@ -1270,16 +1270,16 @@ def __init__(self, db): if db.server_version < 80400: # very old remote databases (not officially supported) self._query_pg_type = ( - "SELECT oid, typname, typname::text::regtype," + "SELECT oid, typname, oid::pg_catalog.regtype," " typlen, typtype, null as typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") + " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") else: self._query_pg_type = ( - "SELECT oid, typname, typname::regtype," + "SELECT oid, typname, oid::pg_catalog.regtype," " typlen, typtype, typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") + " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") def add(self, oid, pgtype, regtype, typlen, typtype, category, delim, relid): @@ 
-1636,22 +1636,26 @@ def __init__(self, *args, **kw): if db.server_version < 80400: # very old remote databases (not officially supported) self._query_attnames = ( - "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype," + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," " t.typlen, t.typtype, null as typcategory," " t.typdelim, t.typrelid" " FROM pg_catalog.pg_attribute a" " JOIN pg_catalog.pg_type t" " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " %s::pg_catalog.regclass" " AND %s AND NOT a.attisdropped ORDER BY a.attnum") else: self._query_attnames = ( - "SELECT a.attname, t.oid, t.typname, t.typname::regtype," + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" " FROM pg_catalog.pg_attribute a" " JOIN pg_catalog.pg_type t" " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " %s::pg_catalog.regclass" " AND %s AND NOT a.attisdropped ORDER BY a.attnum") db.set_cast_hook(self.dbtypes.typecast) # For debugging scripts, self.debug can be set @@ -2097,7 +2101,8 @@ def pkey(self, table, composite=False, flush=False): " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " %s::pg_catalog.regclass" " AND i.indisprimary ORDER BY a.attnum") % ( _quote_if_unqualified('$1', table),) pkey = self.db.query(q, (table,)).getresult() diff --git a/pgdb.py b/pgdb.py index 6f5116d0..7b0eaefc 100644 --- a/pgdb.py +++ b/pgdb.py @@ -795,7 +795,7 @@ def __missing__(self, key): else: if '.' 
not in key and '"' not in key:
                 key = '"%s"' % (key,)
-            oid = "'%s'::regtype" % (self._escape_string(key),)
+            oid = "'%s'::pg_catalog.regtype" % (self._escape_string(key),)
             try:
                 self._src.execute(self._query_pg_type % (oid,))
             except ProgrammingError:
diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py
index 25ade713..7e55c246 100755
--- a/tests/test_classic_dbwrapper.py
+++ b/tests/test_classic_dbwrapper.py
@@ -4762,6 +4762,13 @@ def testMunging(self):
         else:
             self.assertNotIn('oid(t)', r)
 
+    def testQueryInformationSchema(self):
+        q = ("select array_agg(column_name) from information_schema.columns"
+             " where table_schema in ('s1', 's2', 's3', 's4')")
+        r = self.db.query(q).onescalar()
+        self.assertIsInstance(r, list)
+        self.assertEqual(set(r), set(['d', 'n'] * 8))
+
 
 class TestDebug(unittest.TestCase):
     """Test the debug attribute of the DB class."""

From 17792e1215fe35c5889668c3617cc0af2003d9ed Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sun, 30 Jan 2022 15:25:34 +0100
Subject: [PATCH 089/194] Free memory in case of some errors in inserttable

---
 pgconn.c | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pgconn.c b/pgconn.c
index d1b0d2f8..3c33fbb2 100644
--- a/pgconn.c
+++ b/pgconn.c
@@ -780,7 +780,7 @@ conn_inserttable(connObject *self, PyObject *args)
             else if (PyUnicode_Check(obj)) {
                 obj = get_encoded_string(obj, encoding);
                 if (!obj) {
-                    Py_DECREF(iter_row);
+                    PyMem_Free(buffer); Py_DECREF(iter_row);
                     return NULL; /* pass the UnicodeEncodeError */
                 }
                 PyBytes_AsStringAndSize(obj, &col, &slen);
@@ -789,7 +789,7 @@ conn_inserttable(connObject *self, PyObject *args)
                 PyErr_SetString(
                     PyExc_TypeError,
                     "The third argument must contain only strings");
-                Py_DECREF(iter_row);
+                PyMem_Free(buffer); Py_DECREF(iter_row);
                 return NULL;
             }
             col = PQescapeIdentifier(self->cnx, col, (size_t) slen);

From 2a03fccb7c0ac4c741296ea2d33089c318ea9663 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sun, 30 Jan 2022 15:38:30 +0100
Subject: [PATCH 090/194] Fix unit tests for old database server versions

---
 tests/test_classic_dbwrapper.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py
index 7e55c246..dfb2eecd 100755
--- a/tests/test_classic_dbwrapper.py
+++ b/tests/test_classic_dbwrapper.py
@@ -4763,8 +4763,11 @@ def testMunging(self):
             self.assertNotIn('oid(t)', r)
 
     def testQueryInformationSchema(self):
-        q = ("select array_agg(column_name) from information_schema.columns"
-             " where table_schema in ('s1', 's2', 's3', 's4')")
+        q = "column_name"
+        if self.db.server_version < 110000:
+            q += "::text"  # old version does not have sql_identifier array
+        q = "select array_agg(%s) from information_schema.columns" % q
+        q += " where table_schema in ('s1', 's2', 's3', 's4')"
         r = self.db.query(q).onescalar()
         self.assertIsInstance(r, list)
         self.assertEqual(set(r), set(['d', 'n'] * 8))

From 97cfe05908797e4149dd166b4d1670b25aa49f01 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sun, 30 Jan 2022 20:07:17 +0100
Subject: [PATCH 091/194] Add release date to changelog

---
 docs/contents/changelog.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst
index b813c271..c1c1134e 100644
--- a/docs/contents/changelog.rst
+++ b/docs/contents/changelog.rst
@@ -1,8 +1,8 @@
 ChangeLog
 =========
 
-Version 5.2.3 (to be released)
-------------------------------
+Version 5.2.3 (2022-01-30)
+--------------------------
- This version officially supports the new Python 3.10 and PostgreSQL 14. - Some improvements and fixes in the `inserttable()` method of the `pg` module: - Sync with `PQendcopy()` when there was an error (#60) From 481a2cac926141e6502a6868c6d97b0107b4a542 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 31 Jan 2022 00:04:57 +0100 Subject: [PATCH 092/194] Fix indentation in the changelog --- docs/contents/changelog.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index c1c1134e..d2591c51 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,15 +5,15 @@ Version 5.2.3 (2022-01-30) -------------------------- - This version officially supports the new Python 3.10 and PostgreSQL 14. - Some improvements and fixes in the `inserttable()` method of the `pg` module: - - Sync with `PQendcopy()` when there was an error (#60) - - Allow specifying a schema in the table name (#61) - - Improved check for internal result (#62) - - Catch buffer overflows when building the copy command - - Data can now be passed as an iterable, not just list or tuple (#66) + - Sync with `PQendcopy()` when there was an error (#60) + - Allow specifying a schema in the table name (#61) + - Improved check for internal result (#62) + - Catch buffer overflows when building the copy command + - Data can now be passed as an iterable, not just list or tuple (#66) - Some more fixes in the `pg` module: - - Fix upsert with limited number of columns (#58). - - Fix argument handling of `is/set_non_blocking()`. - - Add missing `get/set_typecasts` in list of exports. + - Fix upsert with limited number of columns (#58). + - Fix argument handling of `is/set_non_blocking()`. + - Add missing `get/set_typecasts` in list of exports. - Fixed a reference counting issue when casting JSON columns (#57). Version 5.2.2 (2020-12-09) From 9304bf9ba4a0dd5e34ee4ab947c816075d2909e6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 5 Feb 2022 15:19:43 +0100 Subject: [PATCH 093/194] Fix tox envlist --- tox.ini | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index f9307752..d5a72960 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = 27,3{5,6,7,8,9,10},flake8,docs +envlist = py27,py3{5,6,7,8,9,10},flake8,docs [testenv:flake8] basepython = python3.9 @@ -22,4 +22,3 @@ passenv = PG* PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size python -m unittest discover {posargs} - From e27faeca0b28e239fc8f87fed3c96f111032a720 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 5 Feb 2022 15:24:17 +0100 Subject: [PATCH 094/194] inserttable() failed to escape carriage return (#68) Also, follow the recommendation to use the proper escape sequences instead of just preceding the critical characters with a backslash. 
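Background for the hunks below: the old code only prefixed backslash, tab and newline with a backslash and left carriage returns untouched, which corrupts the COPY text stream (#68). The new switch statements emit the conventional two-character escape sequences instead. The same mapping, rendered in Python for clarity (a sketch only, not part of the patch):

    # COPY text format wants the two-character sequences \\, \t, \r and \n
    # rather than a bare backslash in front of the raw control character.
    COPY_ESCAPES = {'\\': '\\\\', '\t': '\\t', '\r': '\\r', '\n': '\\n'}

    def escape_copy_field(value):
        # map each critical character to its escape sequence, keep the rest
        return ''.join(COPY_ESCAPES.get(c, c) for c in value)

    assert escape_copy_field('a\tb\r\nc\\d') == 'a\\tb\\r\\nc\\\\d'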
--- pgconn.c | 72 ++++++++++++++++++++++++++------ tests/test_classic_connection.py | 14 +++++++ 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/pgconn.c b/pgconn.c index 3c33fbb2..b71d3db3 100644 --- a/pgconn.c +++ b/pgconn.c @@ -868,11 +868,27 @@ conn_inserttable(connObject *self, PyObject *args) const char* t = PyBytes_AsString(item); while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 't'; + break; + case '\r': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'n'; + break; + default: + *bufpt ++= *t; } - *bufpt++ = *t++; --bufsiz; + ++t; --bufsiz; } } else if (PyUnicode_Check(item)) { @@ -886,11 +902,27 @@ conn_inserttable(connObject *self, PyObject *args) const char* t = PyBytes_AsString(s); while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 't'; + break; + case '\r': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'n'; + break; + default: + *bufpt ++= *t; } - *bufpt++ = *t++; --bufsiz; + ++t; --bufsiz; } Py_DECREF(s); } @@ -909,11 +941,27 @@ conn_inserttable(connObject *self, PyObject *args) const char* t = PyStr_AsString(s); while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 't'; + break; + case '\r': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) *bufpt ++= 'n'; + break; + default: + *bufpt ++= *t; } - *bufpt++ = *t++; --bufsiz; + ++t; --bufsiz; } Py_DECREF(s); } diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 9d30b0ad..0a83f3e1 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2138,6 +2138,20 @@ def testInserttableFromQuery(self): (2, 4, 8, True, None, None, None, 4.5, 8.5, None, 'c', 'v4', None, 'text')]) + def testInserttableSpecialChars(self): + class S(object): + def __repr__(self): + return s + + s = '1\'2"3\b4\f5\n6\r7\t8\b9\\0' + s1 = s.encode('ascii') if unicode_strings else s.decode('ascii') + s2 = S() + d = self.data[-1][:-1] + data = [d + (s,), d + (s1,), d + (s2,)] + self.c.inserttable('test', data) + data = [data[0]] * 3 + self.assertEqual(self.get_back(), data) + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From d8c2be716baa58d4e51f24faac5b5f68ea61aca8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 5 Feb 2022 22:21:26 +0100 Subject: [PATCH 095/194] Simplify inserttable() test for special chars --- tests/test_classic_connection.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0a83f3e1..5304c783 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2146,11 +2146,10 @@ def __repr__(self): s = '1\'2"3\b4\f5\n6\r7\t8\b9\\0' s1 = 
s.encode('ascii') if unicode_strings else s.decode('ascii') s2 = S() - d = self.data[-1][:-1] - data = [d + (s,), d + (s1,), d + (s2,)] - self.c.inserttable('test', data) - data = [data[0]] * 3 - self.assertEqual(self.get_back(), data) + data = [(t,) for t in (s, s1, s2)] + self.c.inserttable('test', data, ['t']) + self.assertEqual( + self.c.query('select t from test').getresult(), [(s,)] * 3) class TestDirectSocketAccess(unittest.TestCase): From c55d00342d77ee99757fc18f372c0c80afc54f7f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 22 Mar 2022 20:01:01 +0100 Subject: [PATCH 096/194] Fix use after free issue in inserttable() (#71) --- docs/contents/changelog.rst | 6 ++++++ pgconn.c | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index d2591c51..8b1bbb75 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 5.2.4 (to be released) +------------------------------ +- Two more fixes in the `inserttable()` method of the `pg` module: + - `inserttable()` failed to escape carriage return (#68) + - Fix use after free issue in `inserttable()` (#71) + Version 5.2.3 (2022-01-30) -------------------------- - This version officially supports the new Python 3.10 and PostgreSQL 14. diff --git a/pgconn.c b/pgconn.c index b71d3db3..68b186af 100644 --- a/pgconn.c +++ b/pgconn.c @@ -775,7 +775,7 @@ conn_inserttable(connObject *self, PyObject *args) char *col; if (PyBytes_Check(obj)) { - PyBytes_AsStringAndSize(obj, &col, &slen); + Py_INCREF(obj); } else if (PyUnicode_Check(obj)) { obj = get_encoded_string(obj, encoding); @@ -783,8 +783,6 @@ conn_inserttable(connObject *self, PyObject *args) PyMem_Free(buffer); Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } - PyBytes_AsStringAndSize(obj, &col, &slen); - Py_DECREF(obj); } else { PyErr_SetString( PyExc_TypeError, @@ -792,7 +790,9 @@ conn_inserttable(connObject *self, PyObject *args) PyMem_Free(buffer); Py_DECREF(iter_row); return NULL; } + PyBytes_AsStringAndSize(obj, &col, &slen); col = PQescapeIdentifier(self->cnx, col, (size_t) slen); + Py_DECREF(obj); if (bufpt < bufmax) bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s%s", col, j == n - 1 ? 
")" : ","); From d2748c741c8df7bed6cf1f95a09fe82b88f19be2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Mar 2022 15:25:31 +0100 Subject: [PATCH 097/194] Increase maximum buffer size to 64 KB (tsolves #69) --- docs/contents/changelog.rst | 4 +++- pgmodule.c | 2 +- tests/test_classic_connection.py | 23 ++++++++++++++++++++--- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 8b1bbb75..76bfaf28 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,9 +3,11 @@ ChangeLog Version 5.2.4 (to be released) ------------------------------ -- Two more fixes in the `inserttable()` method of the `pg` module: +- Three more fixes in the `inserttable()` method of the `pg` module: - `inserttable()` failed to escape carriage return (#68) + - Allow larger row sizes up to 64 KB (#69) - Fix use after free issue in `inserttable()` (#71) +- The `getline()` method of `pg` connections now also supports up to 64 KB size Version 5.2.3 (2022-01-30) -------------------------- diff --git a/pgmodule.c b/pgmodule.c index 3096fdd9..07982aad 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -57,7 +57,7 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define QUERY_MOVENEXT 3 #define QUERY_MOVEPREV 4 -#define MAX_BUFFER_SIZE 8192 /* maximum transaction size */ +#define MAX_BUFFER_SIZE 65536 /* maximum transaction size */ #define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ /* MODULE GLOBAL VARIABLES */ diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 5304c783..90a4375f 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2014,9 +2014,14 @@ def testInserttableWithInvalidColumList(self): self.assertFalse('expected an error') def testInserttableWithHugeListOfColumnNames(self): - # should catch buffer overflow when building the column specification - self.assertRaises(MemoryError, self.c.inserttable, - 'test', self.data, ['very_long_column_name'] * 1000) + data = self.data + # try inserting data with a huge list of column names + cols = ['very_long_column_name'] * 2000 + # Should raise a value error because the column does not exist + self.assertRaises(ValueError, self.c.inserttable, 'test', data, cols) + # double the size, should catch buffer overflow and raise memory error + cols *= 2 + self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), @@ -2151,6 +2156,18 @@ def __repr__(self): self.assertEqual( self.c.query('select t from test').getresult(), [(s,)] * 3) + def testInsertTableBigRowSize(self): + # inserting rows with a size of up to 64k bytes should work + t = '*' * 50000 + data = [(t,)] + self.c.inserttable('test', data, ['t']) + self.assertEqual( + self.c.query('select t from test').getresult(), data) + # double the size, should catch buffer overflow and raise memory error + t *= 2 + data = [(t,)] + self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From f538701a292cb7a72b474226b41c12cd73d91c32 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Mar 2022 15:31:16 +0100 Subject: [PATCH 098/194] Use the latest Sphinx version --- .github/workflows/docs.yml | 6 +++--- docs/requirements.txt | 2 +- tox.ini | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff 
--git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 1f5e042d..32ea4e43 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -12,16 +12,16 @@ jobs: steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.9 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=2.4,<3" + pip install "sphinx>=4.4,<5" pip install "cloud_sptheme>=1.10,<2" - name: Create docs with Sphinx run: | diff --git a/docs/requirements.txt b/docs/requirements.txt index dde5e355..a59b8f44 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,2 @@ -sphinx>=3.2,<4 +sphinx>=4.4,<5 cloud_sptheme>=1.10,<2 diff --git a/tox.ini b/tox.ini index d5a72960..dd98fa29 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ commands = [testenv:docs] basepython = python3.9 deps = - sphinx>=3.5,<4 + sphinx>=4.4,<5 cloud_sptheme>=1.10,<2 commands = sphinx-build -b html -nEW docs docs/_build/html From deae56f688d81adc8913316f7a0884ef2d526296 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Mar 2022 17:47:04 +0100 Subject: [PATCH 099/194] Replace obsolete functions for copy used internally (#59) --- docs/contents/changelog.rst | 3 +- pgconn.c | 107 ++++++++++++++++++------------- tests/test_classic_connection.py | 7 +- 3 files changed, 66 insertions(+), 51 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 76bfaf28..3db4fb40 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -7,7 +7,8 @@ Version 5.2.4 (to be released) - `inserttable()` failed to escape carriage return (#68) - Allow larger row sizes up to 64 KB (#69) - Fix use after free issue in `inserttable()` (#71) -- The `getline()` method of `pg` connections now also supports up to 64 KB size +- Replace obsolete functions for copy used internally (#59). + Therefore, `getline()` now does not return `\.` at the end any more. Version 5.2.3 (2022-01-30) -------------------------- diff --git a/pgconn.c b/pgconn.c index 68b186af..6d12ede4 100644 --- a/pgconn.c +++ b/pgconn.c @@ -550,23 +550,27 @@ conn_putline(connObject *self, PyObject *args) { char *line; Py_ssize_t line_length; + int ret; if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); return NULL; } - /* reads args */ + /* read args */ if (!PyArg_ParseTuple(args, "s#", &line, &line_length)) { PyErr_SetString(PyExc_TypeError, "Method putline() takes a string argument"); return NULL; } - /* sends line to backend */ - if (PQputline(self->cnx, line)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + /* send line to backend */ + ret = PQputCopyData(self->cnx, line, (int) line_length); + if (ret != 1) { + PyErr_SetString(PyExc_IOError, ret == -1 ? 
PQerrorMessage(self->cnx) : + "Line cannot be queued, wait for write-ready and try again"); return NULL; + } Py_INCREF(Py_None); return Py_None; @@ -579,29 +583,39 @@ static char conn_getline__doc__[] = static PyObject * conn_getline(connObject *self, PyObject *noargs) { - char line[MAX_BUFFER_SIZE]; - PyObject *str = NULL; /* GCC */ + char *line = NULL; + PyObject *str = NULL; + int ret; if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); return NULL; } - /* gets line */ - switch (PQgetline(self->cnx, line, MAX_BUFFER_SIZE)) { - case 0: - str = PyStr_FromString(line); - break; - case 1: - PyErr_SetString(PyExc_MemoryError, "Buffer overflow"); - str = NULL; - break; - case EOF: + /* get line synchronously */ + ret = PQgetCopyData(self->cnx, &line, 0); + + /* check result */ + if (ret <= 0) { + if (line != NULL) PQfreemem(line); + if (ret == -1) { + PQgetResult(self->cnx); Py_INCREF(Py_None); - str = Py_None; - break; + return Py_None; + } + PyErr_SetString(PyExc_MemoryError, + ret == -2 ? PQerrorMessage(self->cnx) : + "No line available, wait for read-ready and try again"); + return NULL; } - + if (line == NULL) { + Py_INCREF(Py_None); + return Py_None; + } + /* for backward compatibility, convert terminating newline to zero byte */ + if (*line) line[strlen(line) - 1] = '\0'; + str = PyStr_FromString(line); + PQfreemem(line); return str; } @@ -612,14 +626,20 @@ static char conn_endcopy__doc__[] = static PyObject * conn_endcopy(connObject *self, PyObject *noargs) { + int ret; + if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); return NULL; } - /* ends direct copy */ - if (PQendcopy(self->cnx)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + /* end direct copy */ + ret = PQputCopyEnd(self->cnx, NULL); + if (ret != 1) + { + PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : + "Termination message cannot be queued," + " wait for write-ready and try again"); return NULL; } Py_INCREF(Py_None); @@ -691,7 +711,7 @@ conn_inserttable(connObject *self, PyObject *args) { PGresult *result; char *table, *buffer, *bufpt, *bufmax, *s, *t; - int encoding; + int encoding, ret; size_t bufsiz; PyObject *rows, *iter_row, *item, *columns = NULL; Py_ssize_t i, j, m, n; @@ -714,8 +734,7 @@ conn_inserttable(connObject *self, PyObject *args) { PyErr_SetString( PyExc_TypeError, - "Method inserttable() expects an iterable" - " as second argument"); + "Method inserttable() expects an iterable as second argument"); return NULL; } m = PySequence_Check(rows) ? 
PySequence_Size(rows) : -1;
@@ -824,7 +843,7 @@ conn_inserttable(connObject *self, PyObject *args)
         if (!(columns = PyIter_Next(iter_row))) break;
 
         if (!(PyTuple_Check(columns) || PyList_Check(columns))) {
-            PQendcopy(self->cnx); PyMem_Free(buffer);
+            PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer);
             Py_DECREF(columns); Py_DECREF(iter_row);
             PyErr_SetString(
                 PyExc_TypeError,
@@ -836,7 +855,7 @@ conn_inserttable(connObject *self, PyObject *args)
         if (n < 0) {
             n = j;
         } else if (j != n) {
-            PQendcopy(self->cnx); PyMem_Free(buffer);
+            PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer);
             Py_DECREF(columns); Py_DECREF(iter_row);
             PyErr_SetString(
                 PyExc_TypeError,
@@ -894,7 +913,8 @@ conn_inserttable(connObject *self, PyObject *args)
             else if (PyUnicode_Check(item)) {
                 PyObject *s = get_encoded_string(item, encoding);
                 if (!s) {
-                    PQendcopy(self->cnx); PyMem_Free(buffer);
+                    PQputCopyEnd(self->cnx, "Encoding error");
+                    PyMem_Free(buffer);
                     Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row);
                     return NULL; /* pass the UnicodeEncodeError */
                 }
@@ -967,40 +987,39 @@ conn_inserttable(connObject *self, PyObject *args)
             }
 
             if (bufsiz <= 0) {
-                PQendcopy(self->cnx); PyMem_Free(buffer);
+                PQputCopyEnd(self->cnx, "Memory error"); PyMem_Free(buffer);
                 Py_DECREF(columns); Py_DECREF(iter_row);
                 return PyErr_NoMemory();
             }
         }
-        Py_DECREF(columns);
+        Py_DECREF(columns);
 
-        *bufpt++ = '\n'; *bufpt = '\0';
+        *bufpt++ = '\n';
 
         /* sends data */
-        if (PQputline(self->cnx, buffer)) {
-            PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx));
-            PQendcopy(self->cnx); PyMem_Free(buffer); Py_DECREF(iter_row);
+        ret = PQputCopyData(self->cnx, buffer, (int) (bufpt - buffer));
+        if (ret != 1) {
+            char *errormsg = ret == -1 ?
+                PQerrorMessage(self->cnx) : "Data cannot be queued";
+            PyErr_SetString(PyExc_IOError, errormsg);
+            PQputCopyEnd(self->cnx, errormsg);
+            PyMem_Free(buffer); Py_DECREF(iter_row);
             return NULL;
         }
     }
     Py_DECREF(iter_row);
 
     if (PyErr_Occurred()) {
-        PQendcopy(self->cnx); PyMem_Free(buffer);
+        PQerrorMessage(self->cnx); PyMem_Free(buffer);
         return NULL; /* pass the iteration error */
     }
 
-    /* ends query */
-    if (PQputline(self->cnx, "\\.\n")) {
-        PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx));
-        PQendcopy(self->cnx); PyMem_Free(buffer);
-        return NULL;
-    }
-
-    if (PQendcopy(self->cnx)) {
-        PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx));
+    ret = PQputCopyEnd(self->cnx, NULL);
+    if (ret != 1) {
+        PyErr_SetString(PyExc_IOError, ret == -1 ?
+ PQerrorMessage(self->cnx) : "Data cannot be queued"); PyMem_Free(buffer); return NULL; } diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 90a4375f..32c21870 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2205,7 +2205,6 @@ def testPutline(self): try: for i, v in data: putline("%d\t%s\n" % (i, v)) - putline("\\.\n") finally: self.c.endcopy() r = query("select * from test").getresult() @@ -2222,7 +2221,6 @@ def testPutlineBytesAndUnicode(self): try: putline(u"47\tkäse\n".encode('utf8')) putline("35\twürstel\n") - putline(b"\\.\n") finally: self.c.endcopy() r = query("select * from test").getresult() @@ -2236,14 +2234,12 @@ def testGetline(self): self.c.inserttable('test', data) query("copy test to stdout") try: - for i in range(n + 2): + for i in range(n + 1): v = getline() if i < n: # noinspection PyStringFormat self.assertEqual(v, '%d\t%s' % data[i]) elif i == n: - self.assertEqual(v, '\\.') - else: self.assertIsNone(v) finally: try: @@ -2268,7 +2264,6 @@ def testGetlineBytesAndUnicode(self): v = getline() self.assertIsInstance(v, str) self.assertEqual(v, '73\twürstel') - self.assertEqual(getline(), '\\.') self.assertIsNone(getline()) finally: try: From 45e1dbab363f0ffb8b5d8f1b67c64053fb5770da Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Mar 2022 18:20:31 +0100 Subject: [PATCH 100/194] Bump version number and add release date --- .bumpversion.cfg | 2 +- docs/about.txt | 2 +- docs/announce.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 4 ++-- setup.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 8715d670..31f5835a 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.3 +current_version = 5.2.4 commit = False tag = False diff --git a/docs/about.txt b/docs/about.txt index 9379b9a6..4bcd830f 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.3 needs PostgreSQL 9.0 to 9.6 or 10 to 14, and +The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 14, and Python 2.7 or 3.5 to 3.10. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index 9db1c62d..0c90d212 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -3,10 +3,10 @@ PyGreSQL Announcements ====================== --------------------------------- -Release of PyGreSQL version 5.2.3 +Release of PyGreSQL version 5.2.4 --------------------------------- -Release 5.2.3 of PyGreSQL. +Release 5.2.4 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. diff --git a/docs/conf.py b/docs/conf.py index 721c6bf8..6ea28189 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2.3' +version = release = '5.2.4' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. 
diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 3db4fb40..975ad682 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,8 @@ ChangeLog ========= -Version 5.2.4 (to be released) ------------------------------- +Version 5.2.4 (2022-03-26) +-------------------------- - Three more fixes in the `inserttable()` method of the `pg` module: - `inserttable()` failed to escape carriage return (#68) - Allow larger row sizes up to 64 KB (#69) diff --git a/setup.py b/setup.py index bc1de77c..dcc06594 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.3 +"""Setup script for PyGreSQL version 5.2.4 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It embeds the PostgreSQL query library to allow @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.3' +version = '5.2.4' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): From 373e1a93ca33a8b4254772bbddf360ce989d907c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Jul 2022 21:35:03 +0200 Subject: [PATCH 101/194] Fixed wording of the project description (#75) --- README.rst | 4 ++-- docs/about.txt | 8 ++++---- setup.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index ad7c9b99..98bb30bb 100644 --- a/README.rst +++ b/README.rst @@ -2,8 +2,8 @@ PyGreSQL - Python interface for PostgreSQL ========================================== PyGreSQL is a Python module that interfaces to a PostgreSQL database. -It embeds the PostgreSQL query library to allow easy use of the powerful -PostgreSQL features from a Python script. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. diff --git a/docs/about.txt b/docs/about.txt index 4bcd830f..3463b3b7 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -1,7 +1,7 @@ **PyGreSQL** is an *open-source* `Python `_ module that interfaces to a `PostgreSQL `_ database. -It embeds the PostgreSQL query library to allow easy use of the powerful -PostgreSQL features from a Python script. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. @@ -27,8 +27,8 @@ The Python implementation is copyrighted but freely usable and distributable, even for commercial use. **PyGreSQL** is a Python module that interfaces to a PostgreSQL database. -It embeds the PostgreSQL query library to allow easy use of the powerful -PostgreSQL features from a Python script or application. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. PyGreSQL is developed and tested on a NetBSD system, but it also runs on most other platforms where PostgreSQL and Python is running. It is based diff --git a/setup.py b/setup.py index dcc06594..3cfbe278 100755 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ """Setup script for PyGreSQL version 5.2.4 PyGreSQL is an open-source Python module that interfaces to a -PostgreSQL database. 
It embeds the PostgreSQL query library to allow
-easy use of the powerful PostgreSQL features from a Python script.
+PostgreSQL database. It wraps the lower level C API library libpq
+to allow easy use of the powerful PostgreSQL features from Python.
 
 Authors and history:
 * PyGreSQL written 1997 by D'Arcy J.M. Cain

From 03ff2ad03776bd1f96e8e84b6a330aaf6cf370ba Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 26 Aug 2023 14:38:49 +0200
Subject: [PATCH 102/194] Support Python 3.11 and Postgres 15

---
 .gitignore                       |  2 ++
 LICENSE.txt                      |  2 +-
 docs/about.txt                   |  6 +++---
 docs/announce.rst                |  4 ++--
 docs/conf.py                     |  2 +-
 docs/contents/changelog.rst      |  4 ++++
 docs/contents/install.rst        |  2 +-
 docs/copyright.rst               |  2 +-
 pg.py                            |  6 +++---
 pgconn.c                         |  2 +-
 pgdb.py                          |  2 +-
 pginternal.c                     |  2 +-
 pglarge.c                        |  2 +-
 pgmodule.c                       |  2 +-
 pgnotice.c                       |  2 +-
 pgquery.c                        |  2 +-
 pgsource.c                       |  2 +-
 setup.py                         |  7 ++++---
 tests/test_classic_connection.py |  6 +++---
 tests/test_classic_dbwrapper.py  |  2 +-
 tests/test_classic_functions.py  |  2 +-
 tox.ini                          | 10 +++++-----
 22 files changed, 40 insertions(+), 33 deletions(-)

diff --git a/.gitignore b/.gitignore
index 67e3ae35..71300f9e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,11 +23,13 @@ _build_doctrees/
 docker-compose.yml
 Dockerfile
 Vagrantfile
+Vagrantfile-*
 
 .coverage
 .tox/
 .venv/
 .vagrant/
+.vagrant-*/
 
 Thumbs.db
 .DS_Store
diff --git a/LICENSE.txt b/LICENSE.txt
index c10a5870..eea706fe 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre
 
 Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain
 
-Further modifications copyright (c) 2009-2022 by the PyGreSQL Development Team
+Further modifications copyright (c) 2009-2023 by the PyGreSQL Development Team
 
 PyGreSQL is released under the PostgreSQL License, a liberal Open Source
 license, similar to the BSD or MIT licenses:
diff --git a/docs/about.txt b/docs/about.txt
index 3463b3b7..d1492061 100644
--- a/docs/about.txt
+++ b/docs/about.txt
@@ -5,7 +5,7 @@ powerful PostgreSQL features from Python.
 
 | This software is copyright © 1995, Pascal Andre.
 | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain.
- | Further modifications are copyright © 2009-2022 by the PyGreSQL team.
+ | Further modifications are copyright © 2009-2023 by the PyGreSQL team.
 | For licensing details, see the full :doc:`copyright`.
 
 **PostgreSQL** is a highly scalable, SQL compliant, open source
@@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy
 (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0
 and serves as the "BDFL" of PyGreSQL.
 
-The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 14, and
-Python 2.7 or 3.5 to 3.10. If you need to support older PostgreSQL versions or
+The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and
+Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or
 older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that
 still support them.
diff --git a/docs/announce.rst b/docs/announce.rst
index 0c90d212..cadf376b 100644
--- a/docs/announce.rst
+++ b/docs/announce.rst
@@ -22,8 +22,8 @@ This version has been built and unit tested on:
  - openSUSE
  - Ubuntu
  - Windows 7 and 10 with both MinGW and Visual Studio
- - PostgreSQL 9.0 to 9.6 and 10 to 14 (32 and 64bit)
- - Python 2.7 and 3.5 to 3.10 (32 and 64bit)
+ - PostgreSQL 9.0 to 9.6 and 10 to 15 (32 and 64bit)
+ - Python 2.7 and 3.5 to 3.11 (32 and 64bit)
 
 | D'Arcy J.M.
Cain | darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index 6ea28189..6a9f87e0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ # General information about the project. project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2022, ' + author +copyright = '2023, ' + author # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 975ad682..0944b349 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 5.2.5 (to be released) +------------------------------ +- This version officially supports the new Python 3.11 and PostgreSQL 15. + Version 5.2.4 (2022-03-26) -------------------------- - Three more fixes in the `inserttable()` method of the `pg` module: diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 1b7ef55e..4ef323af 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.7 and 3.5 to 3.10, and PostgreSQL versions 9.0 to 9.6 and 10 to 14. +2.7 and 3.5 to 3.11, and PostgreSQL versions 9.0 to 9.6 and 10 to 15. PyGreSQL will be installed as three modules, a shared library called ``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure diff --git a/docs/copyright.rst b/docs/copyright.rst index 4c9aacc6..9a8113ec 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2022 by the PyGreSQL team. +Further modifications copyright (c) 2009-2023 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/pg.py b/pg.py index b03c3d71..fbd97725 100644 --- a/pg.py +++ b/pg.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. @@ -99,7 +99,7 @@ try: # noinspection PyUnresolvedReferences - from typing import Dict, List, Union + from typing import Dict, List, Union # noqa: F401 has_typing = True except ImportError: # Python < 3.5 has_typing = False @@ -1934,7 +1934,7 @@ def set_parameter(self, parameter, value=None, local=False): value = set(value) if len(value) == 1: value = value.pop() - if not(value is None or isinstance(value, basestring)): + if not (value is None or isinstance(value, basestring)): raise ValueError( 'A single value must be specified' ' when parameter is a set') diff --git a/pgconn.c b/pgconn.c index 6d12ede4..e8548d32 100644 --- a/pgconn.c +++ b/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. 
*/ diff --git a/pgdb.py b/pgdb.py index 7b0eaefc..7eaf9cb0 100644 --- a/pgdb.py +++ b/pgdb.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. diff --git a/pginternal.c b/pginternal.c index 91c565be..6dcad8bc 100644 --- a/pginternal.c +++ b/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pglarge.c b/pglarge.c index ed8f1824..c080d658 100644 --- a/pglarge.c +++ b/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgmodule.c b/pgmodule.c index 07982aad..bbb4b0db 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgnotice.c b/pgnotice.c index 7e5b93c7..ae6b2b68 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgquery.c b/pgquery.c index a04eb68b..852c848b 100644 --- a/pgquery.c +++ b/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgsource.c b/pgsource.c index 4fa04365..053ad02f 100644 --- a/pgsource.c +++ b/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/setup.py b/setup.py index 3cfbe278..cdb20c4f 100755 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # # PyGreSQL - a Python interface for the PostgreSQL database. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. @@ -26,8 +26,8 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.10, -and PostgreSQL versions 9.0 to 9.6 and 10 to 14. +PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.11, +and PostgreSQL versions 9.0 to 9.6 and 10 to 15. 
Use as follows: python setup.py build_ext # to build the module @@ -252,6 +252,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 32c21870..bd423d91 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -195,7 +195,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 150000) + self.assertTrue(90000 <= server_version < 160000) def testAttributeSocket(self): socket = self.connection.socket @@ -871,7 +871,7 @@ def testGetresultUtf8(self): # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] - except(pg.DataError, pg.NotSupportedError): + except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") v = None self.assertIsInstance(v, str) @@ -2623,7 +2623,7 @@ def testSetBool(self): finally: pg.set_bool(use_bool) self.assertIsInstance(r, str) - self.assertIs(r, 't') + self.assertEqual(r, 't') pg.set_bool(True) try: r = query("select true::bool").getresult()[0][0] diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index dfb2eecd..ca87a607 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -252,7 +252,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 150000) + self.assertTrue(90000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index d7c7a720..653fbb87 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -126,7 +126,7 @@ def testPqlibVersion(self): v = pg.get_pqlib_version() self.assertIsInstance(v, long) self.assertGreater(v, 90000) - self.assertLess(v, 150000) + self.assertLess(v, 160000) class TestParseArray(unittest.TestCase): diff --git a/tox.ini b/tox.ini index dd98fa29..67dee9fd 100644 --- a/tox.ini +++ b/tox.ini @@ -1,18 +1,18 @@ # config file for tox [tox] -envlist = py27,py3{5,6,7,8,9,10},flake8,docs +envlist = py27,py3{5,6,7,8,9,10,11},flake8,docs [testenv:flake8] -basepython = python3.9 -deps = flake8>=4,<5 +basepython = python3.11 +deps = flake8>=6,<7 commands = flake8 setup.py pg.py pgdb.py tests [testenv:docs] -basepython = python3.9 +basepython = python3.11 deps = - sphinx>=4.4,<5 + sphinx>=4.5,<5 cloud_sptheme>=1.10,<2 commands = sphinx-build -b html -nEW docs docs/_build/html From e142a51996cd098d1038f7945849e1201f5b04e0 Mon Sep 17 00:00:00 2001 From: justinpryzby Date: Sat, 26 Aug 2023 07:47:02 -0500 Subject: [PATCH 103/194] inserttable: test for errors and return number of tuples as str (#73) Contributed by: Justin Pryzby --- docs/contents/pg/connection.rst | 3 ++- pgconn.c | 17 +++++++++++++---- tests/test_classic_connection.py | 5 +++++ tests/test_classic_dbwrapper.py | 3 ++- tests/test_tutorial.py | 3 ++- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/contents/pg/connection.rst 
b/docs/contents/pg/connection.rst index a9fccfdf..c95adf59 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -494,7 +494,7 @@ inserttable -- insert an iterable into a table :param str table: the table name :param list values: iterable of row values, which must be lists or tuples :param list columns: list or tuple of column names - :rtype: None + :rtype: int :raises TypeError: invalid connection, bad argument type, or too many arguments :raises MemoryError: insert buffer could not be allocated :raises ValueError: unsupported values @@ -506,6 +506,7 @@ of the same size, containing the values for each inserted row. These may contain string, integer, long or double (real) values. ``columns`` is an optional tuple or list of column names to be passed on to the COPY command. +The number of rows affected is returned. .. warning:: diff --git a/pgconn.c b/pgconn.c index 6d12ede4..eb86abed 100644 --- a/pgconn.c +++ b/pgconn.c @@ -1012,7 +1012,7 @@ conn_inserttable(connObject *self, PyObject *args) Py_DECREF(iter_row); if (PyErr_Occurred()) { - PQerrorMessage(self->cnx); PyMem_Free(buffer); + PyMem_Free(buffer); return NULL; /* pass the iteration error */ } @@ -1026,9 +1026,18 @@ conn_inserttable(connObject *self, PyObject *args) PyMem_Free(buffer); - /* no error : returns nothing */ - Py_INCREF(Py_None); - return Py_None; + Py_BEGIN_ALLOW_THREADS + result = PQgetResult(self->cnx); + Py_END_ALLOW_THREADS + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); + PQclear(result); + return NULL; + } else { + long ntuples = atol(PQcmdTuples(result)); + PQclear(result); + return PyInt_FromLong(ntuples); + } } /* Get transaction state. */ diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 32c21870..1fea33be 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2023,6 +2023,11 @@ def testInserttableWithHugeListOfColumnNames(self): cols *= 2 self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) + def testInserttableWithOutOfRangeData(self): + # try inserting data out of range for the column type + # Should raise a value error because of smallint out of range + self.assertRaises(ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) + def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), True, '2999-12-31', '11:59:59', 1e99, diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index dfb2eecd..bd481e09 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4227,10 +4227,11 @@ def testInserttableFromQuery(self): self.createTable('test_table_to', 'n integer, t timestamp') for i in range(1, 4): query("insert into test_table_from values ($1, now())", i) - self.db.inserttable( + n = self.db.inserttable( 'test_table_to', query("select n, t::text from test_table_from")) data_from = query("select * from test_table_from").getresult() data_to = query("select * from test_table_to").getresult() + self.assertEqual(n, 3) self.assertEqual([row[0] for row in data_from], [1, 2, 3]) self.assertEqual(data_from, data_to) diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index dbc93024..6f968560 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -48,7 +48,8 @@ def test_all_steps(self): self.assertEqual(r, {'name': 'banana', 'id': 2}) more_fruits = 'cherimaya durian eggfruit fig grapefruit'.split() data = 
list(enumerate(more_fruits, start=3)) - db.inserttable('fruits', data) + n = db.inserttable('fruits', data) + self.assertEqual(n, 5) q = db.query('select * from fruits') r = str(q).splitlines() self.assertEqual(r[0], 'id| name ') From 2871d27cad0a404b74b40012ccbed48ad500452f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:04:06 +0200 Subject: [PATCH 104/194] Update GitHub workflows --- .github/workflows/docs.yml | 12 ++++++------ .github/workflows/lint.yml | 8 ++++---- .github/workflows/tests.yml | 22 ++++++++++++---------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 32ea4e43..ec18c7ba 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,20 +8,20 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 - - name: Set up Python 3.9 - uses: actions/setup-python@v1 + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=4.4,<5" + pip install "sphinx>=4.5,<5" pip install "cloud_sptheme>=1.10,<2" - name: Create docs with Sphinx run: | diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 437449a1..205d8b54 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,18 +7,18 @@ on: jobs: checks: name: Quality checks run - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Run quality checks run: tox -e flake8,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b6eeecd1..e9269da7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ on: jobs: tests: name: Unit tests run - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: fail-fast: false @@ -22,15 +22,17 @@ jobs: - {python: "3.8", postgres: "12"} - {python: "3.9", postgres: "13"} - {python: "3.10", postgres: "14"} + - {python: "3.11", postgres: "15"} # Opposite extremes of the supported Py/PG range, other architecture - - {python: "2.7", postgres: "14", architecture: "x86"} - - {python: "3.5", postgres: "13", architecture: "x86"} - - {python: "3.6", postgres: "12", architecture: "x86"} - - {python: "3.7", postgres: "11", architecture: "x86"} - - {python: "3.8", postgres: "10", architecture: "x86"} - - {python: "3.9", postgres: "9.6", architecture: "x86"} - - {python: "3.10", postgres: "9.3", architecture: "x86"} + - {python: "2.7", postgres: "15", architecture: "x86"} + - {python: "3.5", postgres: "14", architecture: "x86"} + - {python: "3.6", postgres: "13", architecture: "x86"} + - {python: "3.7", postgres: "12", architecture: "x86"} + - {python: "3.8", postgres: "11", architecture: "x86"} + - {python: "3.9", postgres: "10", architecture: "x86"} + - {python: "3.10", postgres: "9.6", architecture: "x86"} + - {python: "3.11", postgres: "9.3", architecture: "x86"} env: PYGRESQL_DB: test @@ -54,10 +56,10 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: 
actions/setup-python@v2 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Run tests From e29ca9ac798b440f8cb65abc82b17ae4759740b1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:10:24 +0200 Subject: [PATCH 105/194] Fix tox issue with passenv --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 67dee9fd..917e22c0 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,9 @@ commands = sphinx-build -b html -nEW docs docs/_build/html [testenv] -passenv = PG* PYGRESQL_* +passenv = + PG* + PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size python -m unittest discover {posargs} From a2f51859d36a25cf36854d259689b1217f94c639 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:17:33 +0200 Subject: [PATCH 106/194] Remove desupported versions from workflow --- .github/workflows/tests.yml | 19 ++++++++----------- tests/test_classic_connection.py | 3 ++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e9269da7..46eac7c0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,9 +15,9 @@ jobs: fail-fast: false matrix: include: - - {python: "2.7", postgres: "9.3"} - - {python: "3.5", postgres: "9.6"} - - {python: "3.6", postgres: "10"} + # - {python: "2.7", postgres: "9.3"} + # - {python: "3.5", postgres: "9.6"} + # - {python: "3.6", postgres: "10"} - {python: "3.7", postgres: "11"} - {python: "3.8", postgres: "12"} - {python: "3.9", postgres: "13"} @@ -25,14 +25,11 @@ jobs: - {python: "3.11", postgres: "15"} # Opposite extremes of the supported Py/PG range, other architecture - - {python: "2.7", postgres: "15", architecture: "x86"} - - {python: "3.5", postgres: "14", architecture: "x86"} - - {python: "3.6", postgres: "13", architecture: "x86"} - - {python: "3.7", postgres: "12", architecture: "x86"} - - {python: "3.8", postgres: "11", architecture: "x86"} - - {python: "3.9", postgres: "10", architecture: "x86"} - - {python: "3.10", postgres: "9.6", architecture: "x86"} - - {python: "3.11", postgres: "9.3", architecture: "x86"} + - {python: "3.7", postgres: "15", architecture: "x86"} + - {python: "3.8", postgres: "14", architecture: "x86"} + - {python: "3.9", postgres: "13", architecture: "x86"} + - {python: "3.10", postgres: "12", architecture: "x86"} + - {python: "3.11", postgres: "11", architecture: "x86"} env: PYGRESQL_DB: test diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 01932f39..a66af902 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2026,7 +2026,8 @@ def testInserttableWithHugeListOfColumnNames(self): def testInserttableWithOutOfRangeData(self): # try inserting data out of range for the column type # Should raise a value error because of smallint out of range - self.assertRaises(ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) + self.assertRaises( + ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), From 538bb9f40d2e78c845036b4dc9cfaa56e621b707 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 21:14:18 +0200 Subject: [PATCH 107/194] Support generated columns in classic module (#83) --- docs/contents/changelog.rst | 7 ++ docs/contents/pg/db_wrapper.rst | 14 ++++ pg.py | 55 +++++++++++-- 
 tests/test_classic_dbwrapper.py | 140 +++++++++++++++++++++++++++++++-
 4 files changed, 207 insertions(+), 9 deletions(-)

diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst
index 0944b349..a9b1b4fe 100644
--- a/docs/contents/changelog.rst
+++ b/docs/contents/changelog.rst
@@ -4,6 +4,13 @@ ChangeLog
 Version 5.2.5 (to be released)
 ------------------------------
 - This version officially supports the new Python 3.11 and PostgreSQL 15.
+- Two more improvements in the `inserttable()` method of the `pg` module
+  (thanks to Justin Pryzby for this contribution):
+  - error handling has been improved (#72)
+  - the method now returns the number of inserted rows (#73)
+- Another improvement in the `pg` module (#83):
+  - generated columns can be requested with the `get_generated()` method
+  - generated columns are ignored by the insert, update and upsert method

 Version 5.2.4 (2022-03-26)
 --------------------------
diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst
index d2ef4e05..5d587f97 100644
--- a/docs/contents/pg/db_wrapper.rst
+++ b/docs/contents/pg/db_wrapper.rst
@@ -136,6 +136,20 @@ By default, only a limited number of simple types will be returned.
 You can get the registered types instead, if enabled by calling the
 :meth:`DB.use_regtypes` method.

+get_generated -- get the generated columns of a table
+-----------------------------------------------------
+
+.. method:: DB.get_generated(table)
+
+    Get the generated columns of a table
+
+    :param str table: name of table
+    :returns: a frozenset of column names
+
+Given the name of a table, digs out the set of generated columns.
+
+.. versionadded:: 5.2.5
+
 has_table_privilege -- check table privilege
 --------------------------------------------

diff --git a/pg.py b/pg.py
index fbd97725..371c616b 100644
--- a/pg.py
+++ b/pg.py
@@ -1629,6 +1629,7 @@ def __init__(self, *args, **kw):
         self.dbname = db.db
         self._regtypes = False
         self._attnames = {}
+        self._generated = {}
         self._pkeys = {}
         self._privileges = {}
         self.adapter = Adapter(self)
@@ -1657,6 +1658,17 @@ def __init__(self, *args, **kw):
             " WHERE a.attrelid OPERATOR(pg_catalog.=)"
             " %s::pg_catalog.regclass"
             " AND %s AND NOT a.attisdropped ORDER BY a.attnum")
+        if db.server_version < 100000:
+            self._query_generated = None
+        elif db.server_version < 120000:
+            self._query_generated = (
+                "a.attidentity OPERATOR(pg_catalog.=) 'a'"
+            )
+        else:
+            self._query_generated = (
+                "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR"
+                " a.attgenerated OPERATOR(pg_catalog.!=) '')"
+            )
         db.set_cast_hook(self.dbtypes.typecast)
         # For debugging scripts, self.debug can be set
        # * to a string format specification (e.g. in CGI set to "%s<BR>
"), @@ -2130,7 +2142,7 @@ def get_relations(self, kinds=None, system=False): """Get list of relations in connected database of specified kinds. If kinds is None or empty, all kinds of relations are returned. - Otherwise kinds can be a string or sequence of type letters + Otherwise, kinds can be a string or sequence of type letters specifying which kind of relations you want to list. Set the system flag if you want to get the system relations as well. @@ -2190,6 +2202,32 @@ def get_attnames(self, table, with_oid=True, flush=False): attnames[table] = names # cache it return names + def get_generated(self, table, flush=False): + """Given the name of a table, dig out the set of generated columns. + + Returns a set of column names that are generated and unalterable. + + If flush is set, then the internal cache for generated columns will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + query_generated = self._query_generated + if not query_generated: + return frozenset() + generated = self._generated + if flush: + generated.clear() + self._do_debug('The generated cache has been flushed') + try: # cache lookup + names = generated[table] + except KeyError: # cache miss, check the database + q = "a.attnum OPERATOR(pg_catalog.>) 0 AND " + query_generated + q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + names = self.db.query(q, (table,)).getresult() + names = frozenset(name[0] for name in names) + generated[table] = names # cache it + return names + def use_regtypes(self, regtypes=None): """Use registered type names instead of simplified type names.""" if regtypes is None: @@ -2307,8 +2345,8 @@ def insert(self, table, row=None, **kw): be passed as the first parameter. The other parameters are used for providing the data of the row that shall be inserted into the table. If a dictionary is supplied as the second parameter, it starts with - that. Otherwise it uses a blank dictionary. Either way the dictionary - is updated from the keywords. + that. Otherwise, it uses a blank dictionary. + Either way the dictionary is updated from the keywords. The dictionary is then reloaded with the values actually inserted in order to pick up values modified by rules, triggers, etc. 
@@ -2321,13 +2359,14 @@ def insert(self, table, row=None, **kw): if 'oid' in row: del row['oid'] # do not insert oid attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier names, values = [], [] for n in attnames: - if n in row: + if n in row and n not in generated: names.append(col(n)) values.append(adapt(row[n], attnames[n])) if not names: @@ -2360,6 +2399,7 @@ def update(self, table, row=None, **kw): if table.endswith('*'): table = table[:-1].rstrip() # need parent table name attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None if row is None: row = {} @@ -2390,7 +2430,7 @@ def update(self, table, row=None, **kw): values = [] keyname = set(keyname) for n in attnames: - if n in row and n not in keyname: + if n in row and n not in keyname and n not in generated: values.append('%s = %s' % (col(n), adapt(row[n], attnames[n]))) if not values: return row @@ -2461,13 +2501,14 @@ def upsert(self, table, row=None, **kw): if 'oid' in kw: del kw['oid'] # do not update oid attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier names, values = [], [] for n in attnames: - if n in row: + if n in row and n not in generated: names.append(col(n)) values.append(adapt(row[n], attnames[n])) names, values = ', '.join(names), ', '.join(values) @@ -2480,7 +2521,7 @@ def upsert(self, table, row=None, **kw): keyname = set(keyname) keyname.add('oid') for n in attnames: - if n not in keyname: + if n not in keyname and n not in generated: value = kw.get(n, n in row) if value: if not isinstance(value, basestring): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3fa2db69..e97a23e2 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -184,8 +184,8 @@ def testAllDBAttributes(self): 'escape_literal', 'escape_string', 'fileno', 'get', 'get_as_dict', 'get_as_list', - 'get_attnames', 'get_cast_hook', - 'get_databases', 'get_notice_receiver', + 'get_attnames', 'get_cast_hook', 'get_databases', + 'get_generated', 'get_notice_receiver', 'get_parameter', 'get_relations', 'get_tables', 'getline', 'getlo', 'getnotify', 'has_table_privilege', 'host', @@ -1473,6 +1473,53 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') + def testGetGenerated(self): + get_generated = self.db.get_generated + server_version = self.db.server_version + if server_version >= 100000: + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'does_not_exist') + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'has.too.many.dots') + r = get_generated('test') + self.assertIsInstance(r, frozenset) + self.assertFalse(r) + if server_version >= 100000: + table = 'test_get_generated_1' + self.createTable( + table, + 'i int generated always as identity primary key,' + ' j int generated always as identity,' + ' k int generated by default as identity,' + ' n serial, m int') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'i', 'j'}) + if server_version >= 120000: + table = 'test_get_generated_2' + self.createTable( + table, + 'n int, m int generated always as (n + 3) stored,' 
+ ' i int generated always as identity,' + ' j int generated by default as identity') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'m', 'i'}) + + def testGetGeneratedIsCached(self): + server_version = self.db.server_version + if server_version < 100000: + return + get_generated = self.db.get_generated + query = self.db.query + table = 'test_get_generated_2' + self.createTable(table, 'i int primary key') + self.assertFalse(get_generated(table)) + query('alter table %s alter column i' + ' add generated always as identity' % table) + self.assertFalse(get_generated(table)) + self.assertEqual(get_generated(table, flush=True), {'i'}) + def testHasTablePrivilege(self): can = self.db.has_table_privilege self.assertEqual(can('test'), True) @@ -1918,6 +1965,32 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')]) + def testInsertWithGeneratedColumns(self): + insert = self.db.insert + get = self.db.get + server_version = self.db.server_version + table = 'insert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testUpdate(self): update = self.db.update query = self.db.query @@ -2089,6 +2162,38 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') + def testUpdateWithGeneratedColumns(self): + update = self.db.update + get = self.db.get + query = self.db.query + server_version = self.db.server_version + table = 'update_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = query('insert into %s (i, d) values (%d, %d)' % (table, i, d)) + self.assertEqual(r, '1') + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = update(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testUpsert(self): upsert = self.db.upsert query = self.db.query @@ -2349,6 +2454,37 @@ def testUpsertWithQuotedNames(self): r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'No.')]) + def testUpsertWithGeneratedColumns(self): + upsert = self.db.upsert + get = self.db.get + server_version = self.db.server_version + table = 'upsert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + 
table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = upsert(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testClear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' From 8a358e118717fded765d4f1633e39895b6ca26e3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 22:54:46 +0200 Subject: [PATCH 108/194] Add default typecast for sql_identifier --- docs/contents/changelog.rst | 2 ++ pg.py | 2 +- pgdb.py | 2 +- tests/test_classic_dbwrapper.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index a9b1b4fe..c55caaad 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -6,11 +6,13 @@ Version 5.2.5 (to be released) - This version officially supports the new Python 3.11 and PostgreSQL 15. - Two more improvements in the `inserttable()` method of the `pg` module (thanks to Justin Pryzby for this contribution): + - error handling has been improved (#72) - the method now returns the number of inserted rows (#73) - Another improvement in the `pg` module (#83): - generated columns can be requested with the `get_generated()` method - generated columns are ignored by the insert, update and upsert method +- Avoid internal query and error when casting the `sql_identifier` type (#82) Version 5.2.4 (2022-03-26) -------------------------- diff --git a/pg.py b/pg.py index 371c616b..181892ae 100644 --- a/pg.py +++ b/pg.py @@ -1066,7 +1066,7 @@ class Typecasts(dict): # (str functions are ignored but have been added for faster access) defaults = { 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, + 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, diff --git a/pgdb.py b/pgdb.py index 7eaf9cb0..d2d06f4d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -555,7 +555,7 @@ class Typecasts(dict): # (str functions are ignored but have been added for faster access) defaults = { 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, + 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index e97a23e2..4246a7c3 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4899,7 +4899,7 @@ def testMunging(self): else: self.assertNotIn('oid(t)', r) - def testQeryInformationSchema(self): + def testQueryInformationSchema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array From 53e8f10d2eb6d2ab4179ec4c8fee4495cbd9852d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 00:19:54 +0200 Subject: [PATCH 
109/194] Fix multiple calls of getresult() after send_query() --- docs/contents/changelog.rst | 1 + pgquery.c | 13 +++++++++---- tests/test_classic_connection.py | 11 +++++++++++ tests/test_classic_dbwrapper.py | 17 +++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index c55caaad..57739271 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -13,6 +13,7 @@ Version 5.2.5 (to be released) - generated columns can be requested with the `get_generated()` method - generated columns are ignored by the insert, update and upsert method - Avoid internal query and error when casting the `sql_identifier` type (#82) +- Fix issue with multiple calls of `getresult()` after `send_query()` (#80) Version 5.2.4 (2022-03-26) -------------------------- diff --git a/pgquery.c b/pgquery.c index 852c848b..0d7ebc7d 100644 --- a/pgquery.c +++ b/pgquery.c @@ -139,8 +139,9 @@ _get_async_result(queryObject *self, int keep) { Py_END_ALLOW_THREADS if (!self->result) { /* end of result set, return None */ - Py_DECREF(self->pgcnx); - self->pgcnx = NULL; + self->max_row = 0; + self->num_fields = 0; + self->col_types = NULL; Py_INCREF(Py_None); return Py_None; } @@ -161,7 +162,7 @@ _get_async_result(queryObject *self, int keep) { } } else if (result == Py_None) { - /* It's would be confusing to return None here because the + /* It would be confusing to return None here because the caller has to call again until we return None. We can't just consume that final None because we don't know if there are additional statements following this one, so we return @@ -180,7 +181,12 @@ _get_async_result(queryObject *self, int keep) { Py_DECREF(self); return NULL; } + } else if (self->async == 2 && + !self->max_row && !self->num_fields && !self->col_types) { + Py_INCREF(Py_None); + return Py_None; } + /* return the query object itself as sentinel for a normal query result */ return (PyObject *)self; } @@ -722,7 +728,6 @@ query_namedresult(queryObject *self, PyObject *noargs) } if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { - res = PyObject_CallFunction(namediter, "(O)", self); if (!res) return NULL; if (PyList_Check(res)) return res; diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index a66af902..0152feb8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -580,6 +580,17 @@ def testNamedresultAsync(self): self.assertEqual(v._fields, ('alias0',)) self.assertEqual(v.alias0, 0) self.assertIsNone(query.namedresult()) + self.assertIsNone(query.namedresult()) + + def testListFieldsAfterSecondGetResultAsync(self): + q = "select 1 as one" + query = self.c.send_query(q) + self.assertEqual(query.getresult(), [(1,)]) + self.assertEqual(query.listfields(), ('one',)) + self.assertIsNone(query.getresult()) + self.assertEqual(query.listfields(), ()) + self.assertIsNone(query.getresult()) + self.assertEqual(query.listfields(), ()) def testGet3Cols(self): q = "select 1,2,3" diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 4246a7c3..fac7e067 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -158,6 +158,23 @@ def testDeleteDb(self): self.assertRaises(pg.InternalError, db.close) del db + def testAsyncQueryBeforeDeletion(self): + db = DB() + query = db.send_query('select 1') + self.assertEqual(query.getresult(), [(1,)]) + self.assertIsNone(query.getresult()) + 
self.assertIsNone(query.getresult()) + del db + gc.collect() + + def testAsyncQueryAfterDeletion(self): + db = DB() + query = db.send_query('select 1') + del db + gc.collect() + self.assertIsNone(query.getresult()) + self.assertIsNone(query.getresult()) + class TestDBClassBasic(unittest.TestCase): """Test existence of the DB class wrapped pg connection methods.""" From c2a42905ecc84eefb58a0b9f4a08d4ea8839174b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 16:02:09 +0200 Subject: [PATCH 110/194] Test both param styles with DB API 2 --- tests/test_dbapi20.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 94b7ab73..a03dca93 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -95,6 +95,19 @@ def test_percent_sign(self): cur.execute("select 'a %% sign'") self.assertEqual(cur.fetchone(), ('a % sign',)) + def test_paramstyles(self): + self.assertEqual(pgdb.paramstyle, 'pyformat') + con = self._connect() + cur = con.cursor() + # parameters can be passed as tuple + cur.execute("select %s, %s, %s", (123, 'abc', True)) + self.assertEqual(cur.fetchone(), (123, 'abc', True)) + # parameters can be passed as dict + cur.execute("select %(one)s, %(two)s, %(one)s, %(three)s", { + "one": 123, "two": "abc", "three": True + }) + self.assertEqual(cur.fetchone(), (123, 'abc', 123, True)) + def test_callproc_no_params(self): con = self._connect() cur = con.cursor() From 8699ace9aee8ff9ccba38695a757f34b2f9149ff Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 18:07:24 +0200 Subject: [PATCH 111/194] Add test that inserttable does not miss failures --- tests/test_classic_connection.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0152feb8..5c043701 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2185,6 +2185,20 @@ def testInsertTableBigRowSize(self): data = [(t,)] self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) + def testInsertTableSmallIntOverflow(self): + rest_row = self.data[2][1:] + data = [(32000,) + rest_row] + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), data) + data = [(33000,) + rest_row] + try: + self.c.inserttable('test', data) + except ValueError as e: + self.assertIn( + 'value "33000" is out of range for type smallint', str(e)) + else: + self.assertFalse('expected an error') + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From e167bbcc32a5bf70865f31bdcd20cb1f6c827ece Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 12:43:31 +0200 Subject: [PATCH 112/194] Prepare patch release --- .bumpversion.cfg | 2 +- docs/about.txt | 2 +- docs/announce.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 4 ++-- setup.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 31f5835a..3a654eda 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.4 +current_version = 5.2.5 commit = False tag = False diff --git a/docs/about.txt b/docs/about.txt index d1492061..c472a304 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). 
D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and +The current version PyGreSQL 5.2.5 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index cadf376b..a95cb949 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -3,10 +3,10 @@ PyGreSQL Announcements ====================== --------------------------------- -Release of PyGreSQL version 5.2.4 +Release of PyGreSQL version 5.2.5 --------------------------------- -Release 5.2.4 of PyGreSQL. +Release 5.2.5 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. diff --git a/docs/conf.py b/docs/conf.py index 6a9f87e0..1e9e4113 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2.4' +version = release = '5.2.5' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 57739271..bc8322f4 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,8 @@ ChangeLog ========= -Version 5.2.5 (to be released) ------------------------------- +Version 5.2.5 (2023-08-28) +-------------------------- - This version officially supports the new Python 3.11 and PostgreSQL 15. - Two more improvements in the `inserttable()` method of the `pg` module (thanks to Justin Pryzby for this contribution): diff --git a/setup.py b/setup.py index cdb20c4f..89e07f94 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.4 +"""Setup script for PyGreSQL version 5.2.5 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. 
It wraps the lower level C API library libpq @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.4' +version = '5.2.5' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): From f7bc4a36aa83b3e9ac2a6ae7d75ca2261bdbc8e3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 13:00:04 +0200 Subject: [PATCH 113/194] Fix tests for PostgreSQL < 9.5 --- tests/test_classic_dbwrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fac7e067..fa3cc655 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1526,7 +1526,7 @@ def testGetGenerated(self): def testGetGeneratedIsCached(self): server_version = self.db.server_version if server_version < 100000: - return + self.skipTest("database does not support generated columns") get_generated = self.db.get_generated query = self.db.query table = 'test_get_generated_2' @@ -2472,6 +2472,8 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r, [(31, 9009, 'No.')]) def testUpsertWithGeneratedColumns(self): + if self.db.server_version < 90500: + self.skipTest('database does not support upsert') upsert = self.db.upsert get = self.db.get server_version = self.db.server_version From fd8748d5c35c606759e3caac1003947a46f5bd7e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 14:52:55 +0200 Subject: [PATCH 114/194] Add .readthedocs.yaml file --- .readthedocs.yaml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .readthedocs.yaml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..9712e405 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,22 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# We recommend specifying your dependencies to enable reproducible builds: +# https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt From c5f1e58bc1cc2eeb7c6db6c0e61760b768815845 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 15:25:08 +0200 Subject: [PATCH 115/194] Do not use custom domain --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ec18c7ba..358659a1 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -33,6 +33,6 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_branch: gh-pages publish_dir: docs/_build/html - cname: pygresql.org + # cname: pygresql.org enable_jekyll: false force_orphan: true From f3683fc3b39e534557317912863c0294a6c0b09a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 16:27:35 +0200 Subject: [PATCH 116/194] Update Sphinx, get rid of outdated cloud theme --- .github/workflows/docs.yml | 5 +- MANIFEST.in | 3 +- README.rst | 7 +- docs/.gitignore | 1 - docs/Makefile | 198 ++--------------------- docs/_static/pygresql.css_t | 86 ---------- docs/_templates/layout.html | 58 ------- docs/community/source.rst | 6 +- 
docs/conf.py | 307 +++--------------------------------- docs/{toc.txt => index.rst} | 4 +- docs/make.bat | 250 ++--------------------------- docs/requirements.txt | 3 +- docs/start.txt | 15 -- tox.ini | 3 +- 14 files changed, 57 insertions(+), 889 deletions(-) delete mode 100644 docs/.gitignore delete mode 100644 docs/_static/pygresql.css_t delete mode 100644 docs/_templates/layout.html rename docs/{toc.txt => index.rst} (64%) delete mode 100644 docs/start.txt diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 358659a1..5a9ef894 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,7 +3,7 @@ name: Release PyGreSQL documentation on: push: branches: - - master + - main jobs: build: @@ -21,8 +21,7 @@ jobs: sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=4.5,<5" - pip install "cloud_sptheme>=1.10,<2" + pip install "sphinx>=7,<8" - name: Create docs with Sphinx run: | cd docs diff --git a/MANIFEST.in b/MANIFEST.in index 239841c7..9b263981 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -20,5 +20,4 @@ exclude docs/index.rst recursive-include docs/community *.rst recursive-include docs/contents *.rst recursive-include docs/download *.rst -recursive-include docs/_static *.css_t *.ico *.png -recursive-include docs/_templates *.html +recursive-include docs/_static *.ico *.png diff --git a/README.rst b/README.rst index 98bb30bb..a6054363 100644 --- a/README.rst +++ b/README.rst @@ -24,6 +24,7 @@ see the documentation. Documentation ------------- -The documentation is available at `pygresql.org `_. - -At mirror of the documentation can be found at `pygresql.readthedocs.io `_. +The documentation is available at +`pygresql.github.io/PyGreSQL/ `_ +and at `pygresql.readthedocs.io `_, +where you can also find the documentation for older versions. diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index 4a579446..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -index.rst \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 0a1113c9..d4bb2cbb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,192 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . BUILDDIR = _build -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext - +# Put it first so that "make" without argument is like "make help". 
 help:
-	@echo "Please use \`make <target>' where <target> is one of"
-	@echo "  html       to make standalone HTML files"
-	@echo "  dirhtml    to make HTML files named index.html in directories"
-	@echo "  singlehtml to make a single large HTML file"
-	@echo "  pickle     to make pickle files"
-	@echo "  json       to make JSON files"
-	@echo "  htmlhelp   to make HTML files and a HTML help project"
-	@echo "  qthelp     to make HTML files and a qthelp project"
-	@echo "  applehelp  to make an Apple Help Book"
-	@echo "  devhelp    to make HTML files and a Devhelp project"
-	@echo "  epub       to make an epub"
-	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
-	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
-	@echo "  latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
-	@echo "  text       to make text files"
-	@echo "  man        to make manual pages"
-	@echo "  texinfo    to make Texinfo files"
-	@echo "  info       to make Texinfo files and run them through makeinfo"
-	@echo "  gettext    to make PO message catalogs"
-	@echo "  changes    to make an overview of all changed/added/deprecated items"
-	@echo "  xml        to make Docutils-native XML files"
-	@echo "  pseudoxml  to make pseudoxml-XML files for display purposes"
-	@echo "  linkcheck  to check all external links for integrity"
-	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"
-	@echo "  coverage   to run coverage check of the documentation (if enabled)"
-
-clean:
-	rm -rf $(BUILDDIR)/*
-
-html:
-	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
-	@echo
-	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
-
-dirhtml:
-	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
-	@echo
-	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
-
-singlehtml:
-	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
-	@echo
-	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
-
-pickle:
-	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
-	@echo
-	@echo "Build finished; now you can process the pickle files."
-
-json:
-	$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
-	@echo
-	@echo "Build finished; now you can process the JSON files."
-
-htmlhelp:
-	$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
-	@echo
-	@echo "Build finished; now you can run HTML Help Workshop with the" \
-	      ".hhp project file in $(BUILDDIR)/htmlhelp."
-
-qthelp:
-	$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
-	@echo
-	@echo "Build finished; now you can run "qcollectiongenerator" with the" \
-	      ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
-	@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PyGreSQL.qhcp"
-	@echo "To view the help file:"
-	@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PyGreSQL.qhc"
-
-applehelp:
-	$(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
-	@echo
-	@echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
-	@echo "N.B. You won't be able to view it unless you put it in" \
-	      "~/Library/Documentation/Help or install it in your application" \
-	      "bundle."
-
-devhelp:
-	$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
-	@echo
-	@echo "Build finished."
-	@echo "To view the help file:"
-	@echo "# mkdir -p $$HOME/.local/share/devhelp/PyGreSQL"
-	@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PyGreSQL"
-	@echo "# devhelp"
-
-epub:
-	$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
-	@echo
-	@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
- -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through platex and dvipdfmx..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." - -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage - @echo "Testing of coverage in the sources finished, look at the " \ - "results in $(BUILDDIR)/coverage/python.txt." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml - @echo - @echo "Build finished. The XML files are in $(BUILDDIR)/xml." +.PHONY: help Makefile -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml - @echo - @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/pygresql.css_t b/docs/_static/pygresql.css_t deleted file mode 100644 index a3bc4de2..00000000 --- a/docs/_static/pygresql.css_t +++ /dev/null @@ -1,86 +0,0 @@ -{% macro experimental(keyword, value) %} - {% if value %} - -moz-{{keyword}}: {{value}}; - -webkit-{{keyword}}: {{value}}; - -o-{{keyword}}: {{value}}; - -ms-{{keyword}}: {{value}}; - {{keyword}}: {{value}}; - {% endif %} -{% endmacro %} - -{% macro border_radius(value) -%} - {{experimental("border-radius", value)}} -{% endmacro %} - -{% macro box_shadow(value) -%} - {{experimental("box-shadow", value)}} -{% endmacro %} - -.pageheader.related { - text-align: left; - padding: 10px 15px; - border: 1px solid #eeeeee; - margin-bottom: 10px; - {{border_radius("1em 1em 1em 1em")}} - {% if theme_borderless_decor | tobool %} - border-top: 0; - border-bottom: 0; - {% endif %} -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo a, .pageheader.related .logo a:hover { - background: transparent; - color: {{ theme_relbarlinkcolor }}; - border: none; - text-decoration: none; - text-shadow: none; - {{box_shadow("none")}} -} - -.pageheader.related ul { - float: right; - margin: 2px 1em; -} - -.pageheader.related li { - float: left; - margin: 0 0 0 10px; -} - -.pageheader.related li a { - padding: 8px 12px; -} - -.norelbar .subtitle { - font-size: 14px; - line-height: 18px; - font-weight: bold; - letter-spacing: 4px; - text-align: right; - padding: 0 1em; - margin-top: -9px; -} - -.relbar-top .related.norelbar { - height: 22px; - border-bottom: 14px solid #eeeeee; -} - -.relbar-bottom .related.norelbar { - height: 22px; - border-top: 14px solid #eeeeee; -} diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html deleted file mode 100644 index 1cb2ddee..00000000 --- a/docs/_templates/layout.html +++ /dev/null @@ -1,58 +0,0 @@ -{%- extends "cloud/layout.html" %} - -{% set css_files = css_files + ["_static/pygresql.css"] %} - -{# - This layout adds a page header above the standard layout. - It also removes the relbars from all pages that are not part - of the core documentation in the contents/ directory, - adapting the navigation bar (breadcrumb) appropriately. -#} - -{% set is_content = pagename.startswith(('contents/', 'genindex', 'modindex', 'py-', 'search')) %} -{% if is_content %} -{% set master_doc = 'contents/index' %} -{% set parents = parents[1:] %} -{% endif %} - -{% block header %} - - - -{% endblock %} - -{% block relbar1 -%} -{%- if is_content -%} - {{ super() }} -{% else %} -
-{%- endif -%} -{%- endblock %} - -{% block relbar2 -%} -{%- if is_content -%} - {{ super() }} -{%- else -%} -
-{%- endif -%} -{%- endblock %} - -{% block content -%} -{%- if is_content -%} -{{ super() }} -{%- else -%} -
{{ super() }}
-{%- endif -%} -{%- endblock %} diff --git a/docs/community/source.rst b/docs/community/source.rst index 224985fd..497f6280 100644 --- a/docs/community/source.rst +++ b/docs/community/source.rst @@ -4,12 +4,12 @@ Access to the source repository The source code of PyGreSQL is available as a `Git `_ repository on `GitHub `_. -The current master branch of the repository can be cloned with the command:: +The current main branch of the repository can be cloned with the command:: git clone https://github.com/PyGreSQL/PyGreSQL.git -You can also download the master branch as a -`zip archive `_. +You can also download the main branch as a +`zip archive `_. Contributions can be proposed as `pull requests `_ on GitHub. diff --git a/docs/conf.py b/docs/conf.py index 1e9e4113..933c4e38 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,90 +1,26 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# PyGreSQL documentation build configuration file. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os -import shlex -import shutil - -# Import Cloud theme (this will also automatically add the theme directory). -# Note: We add a navigation bar to the cloud them using a custom layout. -if os.environ.get('READTHEDOCS', None) == 'True': - # We cannot use our custom layout here, since RTD overrides layout.html. - use_cloud_theme = False -else: - try: - import cloud_sptheme - use_cloud_theme = True - except ImportError: - use_cloud_theme = False +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -shutil.copyfile('start.txt' if use_cloud_theme else 'toc.txt', 'index.rst') - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc'] +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] if use_cloud_theme else [] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. project = 'PyGreSQL' author = 'The PyGreSQL team' copyright = '2023, ' + author -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The full version, including alpha/beta/rc tags. 
version = release = '5.2.5' -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None +language = 'en' -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = ['_build'] +extensions = ['sphinx.ext.autodoc'] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # List of pages which are included in other pages and therefore should # not appear in the toctree. @@ -93,8 +29,6 @@ 'community/mailinglist.rst', 'community/source.rst', 'community/issues.rst', 'community/support.rst', 'community/homes.rst'] -if use_cloud_theme: - exclude_patterns += ['about.rst'] # ignore certain warnings # (references to some of the Python names do not resolve correctly) @@ -102,13 +36,14 @@ nitpick_ignore = [ ('py:' + t, n) for t, names in { 'attr': ('arraysize', 'error', 'sqlstate', 'DatabaseError.sqlstate'), - 'class': ('bool', 'bytes', 'callable', 'class', + 'class': ('bool', 'bytes', 'callable', 'callables', 'class', 'dict', 'float', 'function', 'int', 'iterable', 'list', 'object', 'set', 'str', 'tuple', 'False', 'True', 'None', - 'namedtuple', 'OrderedDict', 'decimal.Decimal', + 'namedtuple', 'namedtuples', + 'OrderedDict', 'decimal.Decimal', 'bytes/str', 'list of namedtuples', 'tuple of callables', - 'type of first field', + 'first field', 'type of first field', 'Notice', 'DATETIME'), 'data': ('defbase', 'defhost', 'defopt', 'defpasswd', 'defport', 'defuser'), @@ -125,217 +60,15 @@ 'obj': ('False', 'True', 'None') }.items() for n in names] -# The reST default role (used for this markup: `text`) for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. 
-html_theme = 'cloud' if use_cloud_theme else 'default'
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-if use_cloud_theme:
-    html_theme_options = {
-        'roottarget': 'contents/index',
-        'defaultcollapsed': True,
-        'shaded_decor': True}
-else:
-    html_theme_options = {}
-
-# Add any paths that contain custom themes here, relative to this directory.
-html_theme_path = ['_themes']
+html_theme = 'alabaster'
+html_static_path = ['_static']

-# The name for this set of Sphinx documents. If None, it defaults to
-# "<project> v<release> documentation".
 html_title = 'PyGreSQL %s' % version
-if use_cloud_theme:
-    html_title += ' documentation'
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
 html_logo = '_static/pygresql.png'
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
 html_favicon = '_static/favicon.ico'
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
-
-# Add any extra paths that contain custom files (such as robots.txt or
-# .htaccess) here, relative to this directory. These files are copied
-# directly to the root of the documentation.
-#html_extra_path = []
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-#html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-#html_additional_pages = {}
-
-# If false, no module index is generated.
-#html_domain_indices = True
-
-# If false, no index is generated.
-#html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-#html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a <link> tag referring to it. The value of this option must be the
-# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
-
-# This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
-
-# Language to be used for generating the HTML full-text search index.
-# Sphinx supports the following languages:
-#   'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
-#   'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
-#html_search_language = 'en'
-
-# A dictionary with options for the search language support, empty by default.
-# Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'PyGreSQLdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'PyGreSQL.tex', 'PyGreSQL Documentation', - author, 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'pygresql', 'PyGreSQL Documentation', [author], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'PyGreSQL', u'PyGreSQL Documentation', - author, 'PyGreSQL', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False diff --git a/docs/toc.txt b/docs/index.rst similarity index 64% rename from docs/toc.txt rename to docs/index.rst index 441021b4..c40103a8 100644 --- a/docs/toc.txt +++ b/docs/index.rst @@ -1,5 +1,3 @@ -.. PyGreSQL index page with toc (for use without cloud theme) - Welcome to PyGreSQL =================== @@ -11,4 +9,4 @@ Welcome to PyGreSQL announce download/index contents/index - community/index \ No newline at end of file + community/index diff --git a/docs/make.bat b/docs/make.bat index b8571b60..954237b9 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,62 +1,16 @@ @ECHO OFF +pushd %~dp0 + REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) +set SOURCEDIR=. set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . 
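+REM Note: all build targets are delegated to sphinx-build's -M make mode
+REM below, so the old per-target option variables are no longer needed.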
-if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - echo. coverage to run coverage check of the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -REM Check if sphinx-build is available and fallback to Python version if any -%SPHINXBUILD% 1>NUL 2>NUL -if errorlevel 9009 goto sphinx_python -goto sphinx_ok -:sphinx_python - -set SPHINXBUILD=python -m sphinx.__init__ -%SPHINXBUILD% 2> nul +%SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx @@ -65,199 +19,17 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://www.sphinx-doc.org/ exit /b 1 ) -:sphinx_ok - - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. 
- echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\PyGreSQL.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\PyGreSQL.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) +if "%1" == "" goto help -if "%1" == "coverage" ( - %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage - if errorlevel 1 exit /b 1 - echo. - echo.Testing of coverage in the sources finished, look at the ^ -results in %BUILDDIR%/coverage/python.txt. - goto end -) +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
- goto end -) +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt index a59b8f44..9cd8b2f5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1 @@ -sphinx>=4.4,<5 -cloud_sptheme>=1.10,<2 +sphinx>=7,<8 diff --git a/docs/start.txt b/docs/start.txt deleted file mode 100644 index 5166896a..00000000 --- a/docs/start.txt +++ /dev/null @@ -1,15 +0,0 @@ -.. PyGreSQL index page without toc (for use with cloud theme) - -Welcome to PyGreSQL -=================== - -.. toctree:: - :hidden: - - copyright - announce - download/index - contents/index - community/index - -.. include:: about.txt \ No newline at end of file diff --git a/tox.ini b/tox.ini index 917e22c0..1199be00 100644 --- a/tox.ini +++ b/tox.ini @@ -12,8 +12,7 @@ commands = [testenv:docs] basepython = python3.11 deps = - sphinx>=4.5,<5 - cloud_sptheme>=1.10,<2 + sphinx>=7,<8 commands = sphinx-build -b html -nEW docs docs/_build/html From 89eaebaeecd6a94ed532d1eb04f32227cb6f018d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 14:02:07 +0200 Subject: [PATCH 117/194] Add dev container for VS Code Also start de-supporting older Python versions. --- .devcontainer/devcontainer.json | 63 +++++++++++++++++++++++ .devcontainer/provision.sh | 82 ++++++++++++++++++++++++++++++ .github/workflows/docs.yml | 52 +++++++++---------- .github/workflows/lint.yml | 6 ++- .github/workflows/tests.yml | 31 ++++++----- .vscode/settings.json | 6 +++ MANIFEST.in | 1 - tests/config.py | 21 +++++--- tests/test_classic_notification.py | 8 --- tox.ini | 4 +- 10 files changed, 212 insertions(+), 62 deletions(-) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .devcontainer/provision.sh create mode 100644 .vscode/settings.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..c1374910 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,63 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu +{ + "name": "PyGreSQL", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "dockerComposeFile": "docker-compose.yml", + "service": "dev", + "workspaceFolder": "/workspace", + "customizations": { + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "terminal.integrated.profiles.linux": { + "bash": { + "path": "/bin/bash" + } + }, + "sqltools.connections": [ + { + "name": "Container database", + "driver": "PostgreSQL", + "previewLimit": 50, + "server": "pg15", + "port": 5432, + "database": "test", + "username": "test", + "password": "test" + } + ], + "python.pythonPath": "/usr/local/bin/python", + "python.analysis.typeCheckingMode": "basic", + "python.testing.unittestEnabled": true, + "editor.formatOnSave": true, + "editor.renderWhitespace": "all", + "editor.rulers": [ + 79 + ] + }, + // Add the IDs of extensions you want installed when the container is created. + "extensions": [ + "ms-azuretools.vscode-docker", + "ms-python.python", + "ms-vscode.cpptools", + "mtxr.sqltools", + "njpwerner.autodocstring", + "redhat.vscode-yaml", + "eamodio.gitlens", + "streetsidesoftware.code-spell-checker", + "lextudio.restructuredtext" + ] + } + }, + // Features to add to the dev container. More info: https://containers.dev/features. 
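+    // (none are needed here: the Python toolchain and the PostgreSQL
+    // client tools are installed by provision.sh via the
+    // "postCreateCommand" hook below)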
+ // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash /workspace/.devcontainer/provision.sh" + // Configure tool-specific properties. + // "customizations": {}, + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} \ No newline at end of file diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh new file mode 100644 index 00000000..5cea536f --- /dev/null +++ b/.devcontainer/provision.sh @@ -0,0 +1,82 @@ +#!/usr/bin/bash + +# install development environment for PyGreSQL + +export DEBIAN_FRONTEND=noninteractive + +sudo apt-get update +sudo apt-get -y upgrade + +# install base utilities and configure time zone + +sudo ln -fs /usr/share/zoneinfo/UTC /etc/localtime +sudo apt-get install -y apt-utils software-properties-common +sudo apt-get install -y tzdata +sudo dpkg-reconfigure --frontend noninteractive tzdata + +sudo apt-get install -y rpm wget zip + +# install all supported Python versions + +sudo add-apt-repository -y ppa:deadsnakes/ppa +sudo apt-get update + +sudo apt-get install -y python3.7 python3.7-dev python3.7-distutils +sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils +sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils +sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils +sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils + +# install testing tool + +sudo apt-get install -y tox + +# install PostgreSQL client tools + +sudo apt-get install -y postgresql libpq-dev + +for pghost in pg10 pg12 pg14 pg15 +do + export PGHOST=$pghost + export PGDATABASE=postgres + export PGUSER=postgres + export PGPASSWORD=postgres + + createdb -E UTF8 -T template0 test + createdb -E SQL_ASCII -T template0 test_ascii + createdb -E LATIN1 -l C -T template0 test_latin1 + createdb -E LATIN9 -l C -T template0 test_latin9 + createdb -E ISO_8859_5 -l C -T template0 test_cyrillic + + psql -c "create user test with password 'test'" + + psql -c "grant create on database test to test" + psql -c "grant create on database test_ascii to test" + psql -c "grant create on database test_latin1 to test" + psql -c "grant create on database test_latin9 to test" + psql -c "grant create on database test_cyrillic to test" + + psql -c "grant create on schema public to test" test + psql -c "grant create on schema public to test" test_ascii + psql -c "grant create on schema public to test" test_latin1 + psql -c "grant create on schema public to test" test_latin9 + psql -c "grant create on schema public to test" test_cyrillic + + psql -c "create extension hstore" test + psql -c "create extension hstore" test_ascii + psql -c "create extension hstore" test_latin1 + psql -c "create extension hstore" test_latin9 + psql -c "create extension hstore" test_cyrillic +done + +export PGHOST=pg15 +export PGPORT=5432 +export PGDATABASE=test +export PGUSER=test +export PGPASSWORD=test + +export PYGRESQL_DB=$PGDATABASE +export PYGRESQL_HOST=$PGHOST +export PYGRESQL_PORT=$PGPORT +export PYGRESQL_USER=$PGUSER +export PYGRESQL_PASSWD=$PGPASSWORD diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5a9ef894..aae221a0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,4 +1,4 @@ -name: Release PyGreSQL documentation +name: Publish PyGreSQL documentation on: push: @@ 
-7,31 +7,31 @@ on:
 
 jobs:
   build:
-
     runs-on: ubuntu-22.04
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.11
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.11
-      - name: Install dependencies
-        run: |
-          sudo apt install libpq-dev
-          python -m pip install --upgrade pip
-          pip install .
-          pip install "sphinx>=7,<8"
-      - name: Create docs with Sphinx
-        run: |
-          cd docs
-          make html
-      - name: Deploy docs to GitHub pages
-        uses: peaceiris/actions-gh-pages@v3
-        with:
-          github_token: ${{ secrets.GITHUB_TOKEN }}
-          publish_branch: gh-pages
-          publish_dir: docs/_build/html
-          # cname: pygresql.org
-          enable_jekyll: false
-          force_orphan: true
+      - name: Check out repository
+        uses: actions/checkout@v3
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.11
+      - name: Install dependencies
+        run: |
+          sudo apt install libpq-dev
+          python -m pip install --upgrade pip
+          pip install .
+          pip install "sphinx>=7,<8"
+      - name: Create docs with Sphinx
+        run: |
+          cd docs
+          make html
+      - name: Deploy docs to GitHub pages
+        uses: peaceiris/actions-gh-pages@v3
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          publish_branch: gh-pages
+          publish_dir: docs/_build/html
+          cname: pygresql.org
+          enable_jekyll: false
+          force_orphan: true
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 205d8b54..54ae2fd3 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -13,10 +13,12 @@ jobs:
       fail-fast: false
 
     steps:
-      - uses: actions/checkout@v3
+      - name: Check out repository
+        uses: actions/checkout@v3
       - name: Install tox
         run: pip install tox
-      - uses: actions/setup-python@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
         with:
           python-version: 3.11
       - name: Run quality checks
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 46eac7c0..ca8e4a36 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -9,27 +9,24 @@ on:
 jobs:
   tests:
     name: Unit tests run
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
 
     strategy:
       fail-fast: false
       matrix:
        include:
-          # - {python: "2.7", postgres: "9.3"}
-          # - {python: "3.5", postgres: "9.6"}
-          # - {python: "3.6", postgres: "10"}
-          - {python: "3.7", postgres: "11"}
-          - {python: "3.8", postgres: "12"}
-          - {python: "3.9", postgres: "13"}
-          - {python: "3.10", postgres: "14"}
-          - {python: "3.11", postgres: "15"}
+          - { python: "3.7", postgres: "11" }
+          - { python: "3.8", postgres: "12" }
+          - { python: "3.9", postgres: "13" }
+          - { python: "3.10", postgres: "14" }
+          - { python: "3.11", postgres: "15" }
           # Opposite extremes of the supported Py/PG range, other architecture
-          - {python: "3.7", postgres: "15", architecture: "x86"}
-          - {python: "3.8", postgres: "14", architecture: "x86"}
-          - {python: "3.9", postgres: "13", architecture: "x86"}
-          - {python: "3.10", postgres: "12", architecture: "x86"}
-          - {python: "3.11", postgres: "11", architecture: "x86"}
+          - { python: "3.7", postgres: "15", architecture: "x86" }
+          - { python: "3.8", postgres: "14", architecture: "x86" }
+          - { python: "3.9", postgres: "13", architecture: "x86" }
+          - { python: "3.10", postgres: "12", architecture: "x86" }
+          - { python: "3.11", postgres: "11", architecture: "x86" }
 
     env:
       PYGRESQL_DB: test
@@ -53,10 +50,12 @@
 
           --health-retries 5
 
     steps:
-      - uses: actions/checkout@v3
+      - name: Check out repository
+        uses: actions/checkout@v3
       - name: Install tox
         run: pip install tox
-      - uses: actions/setup-python@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
         with:
          python-version: ${{ 
matrix.python }} - name: Run tests diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9ee86e71 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.autopep8" + }, + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 9b263981..e6e9e5a9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -9,7 +9,6 @@ include LICENSE.txt include tox.ini recursive-include tests *.py -exclude tests/LOCAL_PyGreSQL.py include docs/Makefile include docs/make.bat diff --git a/tests/config.py b/tests/config.py index a6082593..a6dcd3a3 100644 --- a/tests/config.py +++ b/tests/config.py @@ -4,8 +4,10 @@ from os import environ # We need a database to test against. -# If LOCAL_PyGreSQL.py exists, we will get our information from that. -# Otherwise, we use the defaults. + +# The connection parameters are taken from the usual PG* environment +# variables and can be overridden with PYGRESQL_* environment variables +# or values specified in the file .LOCAL_PyGreSQL or LOCAL_PyGreSQL.py. # The tests should be run with various PostgreSQL versions and databases # created with different encodings and locales. Particularly, make sure the @@ -13,11 +15,16 @@ # The current user must have create schema privilege on the database. -dbname = environ.get('PYGRESQL_DB', 'unittest') -dbhost = environ.get('PYGRESQL_HOST', None) -dbport = environ.get('PYGRESQL_PORT', 5432) -dbuser = environ.get('PYGRESQL_USER', None) -dbpasswd = environ.get('PYGRESQL_PASSWD', None) +get = environ.get + +dbname = get('PYGRESQL_DB', get('PGDATABASE')) +dbhost = get('PYGRESQL_HOST', get('PGHOST')) +dbport = get('PYGRESQL_PORT', get('PGPORT')) +dbuser = get('PYGRESQL_USER', get('PGUSER')) +dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) + +if dbport: + dbport = int(dbport) try: from .LOCAL_PyGreSQL import * # noqa: F401 diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 29e6921d..39f607df 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -21,14 +21,6 @@ debug = False # let DB wrapper print debugging output -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass - def DB(): """Create a DB wrapper object connecting to the test database.""" diff --git a/tox.ini b/tox.ini index 1199be00..d48b44c7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py27,py3{5,6,7,8,9,10,11},flake8,docs +envlist = py3{7,8,9,10,11},flake8,docs [testenv:flake8] basepython = python3.11 @@ -22,4 +22,4 @@ passenv = PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size - python -m unittest discover {posargs} + python -m unittest {posargs:discover} From 5771ad75d98863fe97d69af248ba14855d165a7c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:38:24 +0200 Subject: [PATCH 118/194] Properly set dev environment variables --- .devcontainer/dev.env | 11 +++++++++++ .devcontainer/provision.sh | 12 ------------ 2 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 .devcontainer/dev.env diff --git a/.devcontainer/dev.env b/.devcontainer/dev.env new file mode 100644 index 00000000..996ee8d2 --- /dev/null +++ b/.devcontainer/dev.env @@ -0,0 +1,11 @@ +PGHOST=pg15 +PGPORT=5432 +PGDATABASE=test +PGUSER=test 
+PGPASSWORD=test + +PYGRESQL_DB=test +PYGRESQL_HOST=pg15 +PYGRESQL_PORT=5432 +PYGRESQL_USER=test +PYGRESQL_PASSWD=test diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 5cea536f..b47abb8c 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -68,15 +68,3 @@ do psql -c "create extension hstore" test_latin9 psql -c "create extension hstore" test_cyrillic done - -export PGHOST=pg15 -export PGPORT=5432 -export PGDATABASE=test -export PGUSER=test -export PGPASSWORD=test - -export PYGRESQL_DB=$PGDATABASE -export PYGRESQL_HOST=$PGHOST -export PYGRESQL_PORT=$PGPORT -export PYGRESQL_USER=$PGUSER -export PYGRESQL_PASSWD=$PGPASSWORD From 6c03eba7ebd3f998ab9f08e7a4f21171221796e1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:40:26 +0200 Subject: [PATCH 119/194] Ignore VS Code settings --- .gitignore | 1 + .vscode/settings.json | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 71300f9e..83732331 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ Thumbs.db .idea/ .vs/ +.vscode/ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9ee86e71..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} \ No newline at end of file From 97d2a258879c1fb9e22754e6ea189eac917b109c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:46:31 +0200 Subject: [PATCH 120/194] Start desupporting old Python and Postgres versions --- docs/about.txt | 7 +++-- docs/announce.rst | 17 +++++------- docs/contents/install.rst | 2 +- docs/contents/pg/adaptation.rst | 4 +-- docs/contents/pg/connection.rst | 2 +- docs/contents/pgdb/adaptation.rst | 4 +-- pg.py | 13 ++------- pgdb.py | 14 +++------- setup.py | 45 ++++++++----------------------- tests/test_classic_connection.py | 3 +-- tests/test_classic_dbwrapper.py | 3 --- 11 files changed, 32 insertions(+), 82 deletions(-) diff --git a/docs/about.txt b/docs/about.txt index c472a304..04f615e1 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.5 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and -Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or -older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that -still support them. +The current version PyGreSQL 6.0 needs PostgreSQL 10 to 15, and Python +3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index a95cb949..d0a5f19c 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -2,11 +2,11 @@ PyGreSQL Announcements ====================== ---------------------------------- -Release of PyGreSQL version 5.2.5 ---------------------------------- +------------------------------- +Release of PyGreSQL version 6.0 +------------------------------- -Release 5.2.5 of PyGreSQL. +Release 6.0 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. @@ -17,13 +17,10 @@ Please refer to `changelog.txt `_ for things that have changed in this version. 
This version has been built and unit tested on:
 
-  - NetBSD
-  - FreeBSD
-  - openSUSE
   - Ubuntu
-  - Windows 7 and 10 with both MinGW and Visual Studio
-  - PostgreSQL 9.0 to 9.6 and 10 to 15 (32 and 64bit)
-  - Python 2.7 and 3.5 to 3.11 (32 and 64bit)
+  - Windows 7 and 10 with Visual Studio
+  - PostgreSQL 10 to 15 (32 and 64bit)
+  - Python 3.7 to 3.11 (32 and 64bit)
 
 | D'Arcy J.M. Cain
 | darcy@PyGreSQL.org
diff --git a/docs/contents/install.rst b/docs/contents/install.rst
index 4ef323af..d1926881 100644
--- a/docs/contents/install.rst
+++ b/docs/contents/install.rst
@@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains
 ``libpq.dll`` is part of your ``PATH`` environment variable.
 
 The current version of PyGreSQL has been tested with Python versions
-2.7 and 3.5 to 3.11, and PostgreSQL versions 9.0 to 9.6 and 10 to 15.
+3.7 to 3.11, and PostgreSQL versions 10 to 15.
 
 PyGreSQL will be installed as three modules, a shared library called
 ``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure
diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst
index 1cf44418..c5d0a795 100644
--- a/docs/contents/pg/adaptation.rst
+++ b/docs/contents/pg/adaptation.rst
@@ -26,7 +26,7 @@ PostgreSQL                           Python
 char, bpchar, name, text, varchar    str
 bool                                 bool
 bytea                                bytes
-int2, int4, int8, oid, serial        int [#int8]_
+int2, int4, int8, oid, serial        int
 int2vector                           list of int
 float4, float8                       float
 numeric, money                       Decimal
@@ -45,8 +45,6 @@ record                               tuple
 
 Elements of arrays and records will also be converted accordingly.
 
-    .. [#int8] int8 is converted to long in Python 2
-
     .. [#array] The first element of the array will always be the first
       element of the Python list, no matter what the lower bound of the
       PostgreSQL array is. The information about the start index of the
       array (which is usually but not necessarily 1) is lost.
diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst
index c95adf59..d1c95213 100644
--- a/docs/contents/pg/connection.rst
+++ b/docs/contents/pg/connection.rst
@@ -730,7 +730,7 @@ the connection and its status. These attributes are:
 
 .. attribute:: Connection.server_version
 
-    the backend version (int, e.g. 90305 for 9.3.5)
+    the backend version (int, e.g. 150400 for 15.4)
 
     .. versionadded:: 4.0
 
diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst
index 0f9ad5a6..ebb36e5b 100644
--- a/docs/contents/pgdb/adaptation.rst
+++ b/docs/contents/pgdb/adaptation.rst
@@ -26,7 +26,7 @@ PostgreSQL                           Python
 char, bpchar, name, text, varchar    str
 bool                                 bool
 bytea                                bytes
-int2, int4, int8, oid, serial        int [#int8]_
+int2, int4, int8, oid, serial        int
 int2vector                           list of int
 float4, float8                       float
 numeric, money                       Decimal
@@ -45,8 +45,6 @@ record                               tuple
 
 Elements of arrays and records will also be converted accordingly.
 
-    .. [#int8] int8 is converted to long in Python 2
-
     .. [#array] The first element of the array will always be the first
       element of the Python list, no matter what the lower bound of the
       PostgreSQL array is. 
The information about the start index of the array (which is usually but not necessarily 1) is lost.
diff --git a/pg.py b/pg.py
index 181892ae..9d5a7e13 100644
--- a/pg.py
+++ b/pg.py
@@ -96,13 +96,7 @@
 from re import compile as regex
 from json import loads as jsondecode, dumps as jsonencode
 from uuid import UUID
-
-try:
-    # noinspection PyUnresolvedReferences
-    from typing import Dict, List, Union  # noqa: F401
-    has_typing = True
-except ImportError:  # Python < 3.5
-    has_typing = False
+from typing import Dict, List, Union  # noqa: F401
 
 try:  # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
     long
@@ -342,9 +336,6 @@ class _SimpleTypes(dict):
             bytes, unicode, basestring]
     }  # type: Dict[str, List[Union[str, type]]]
 
-    if long is not int:  # Python 2 has a separate long type
-        _type_aliases['num'].append(long)
-
     # noinspection PyMissingConstructor
     def __init__(self):
         """Initialize type mapping."""
@@ -354,7 +345,7 @@ def __init__(self):
             self[key] = typ
             if isinstance(key, str):
                 self['_%s' % key] = '%s[]' % typ
-            elif has_typing and not isinstance(key, tuple):
+            elif not isinstance(key, tuple):
                 self[List[key]] = '%s[]' % typ
 
     @staticmethod
diff --git a/pgdb.py b/pgdb.py
index d2d06f4d..ccf848e9 100644
--- a/pgdb.py
+++ b/pgdb.py
@@ -1008,10 +1008,7 @@ def _quote(self, value):
             if not value:  # exception for empty array
                 return "'{}'"
             q = self._quote
-            try:
-                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
-            except UnicodeEncodeError:  # Python 2 with non-ascii values
-                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
+            return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
         if isinstance(value, tuple):
             # Quote as a ROW constructor. This is better than using a record
             # literal because it carries the information that this is a record
@@ -1019,10 +1016,7 @@ def _quote(self, value):
             # this usable with the IN syntax as well. It is only necessary
             # when the records has a single column which is not really useful.
             q = self._quote
-            try:
-                return '(%s)' % (','.join(str(q(v)) for v in value),)
-            except UnicodeEncodeError:  # Python 2 with non-ascii values
-                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
+            return '(%s)' % (','.join(str(q(v)) for v in value),)
         try:  # noinspection PyUnresolvedReferences
             value = value.__pg_repr__()
         except AttributeError:
@@ -1472,8 +1466,8 @@ def __next__(self):
             raise StopIteration
         return res
 
-    # Note that since Python 2.6 the iterator protocol uses __next()__
-    # instead of next(), we keep it only for backward compatibility of pgdb.
+    # Note that the iterator protocol now uses __next__() instead of next(),
+    # but we keep it for backward compatibility of pgdb.
     next = __next__
 
     @staticmethod
diff --git a/setup.py b/setup.py
index 89e07f94..fb5330e8 100755
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 #
 # Please see the LICENSE.TXT file for specific restrictions.
 
-"""Setup script for PyGreSQL version 5.2.5
+"""Setup script for PyGreSQL version 6.0
 
 PyGreSQL is an open-source Python module that interfaces to a
 PostgreSQL database. It wraps the lower level C API library libpq
@@ -26,8 +26,8 @@
 * PostgreSQL pg_config tool (usually included in the devel package)
   (the Windows installer has it as part of the database server feature)
 
-PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.11,
-and PostgreSQL versions 9.0 to 9.6 and 10 to 15.
+PyGreSQL currently supports Python versions 3.7 to 3.11,
+and PostgreSQL versions 10 to 15. 
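+
+(On an interpreter outside this range, the version check further down
+in this script aborts the installation early with an explanatory error.)
+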
Use as follows: python setup.py build_ext # to build the module @@ -52,10 +52,9 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.5' +version = '6.0' -if not (sys.version_info[:2] == (2, 7) - or (3, 5) <= sys.version_info[:2] < (4, 0)): +if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( "Sorry, PyGreSQL %s does not support this Python version" % version) @@ -84,7 +83,7 @@ def pg_version(): match = re.search(r'(\d+)\.(\d+)', pg_config('version')) if match: return tuple(map(int, match.groups())) - return 9, 0 + return 10, 0 pg_version = pg_version() @@ -146,7 +145,7 @@ def initialize_options(self): self.pqlib_info = None self.ssl_info = None self.memory_size = None - supported = pg_version >= (9, 0) + supported = pg_version >= (10, 0) if not supported: warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.") @@ -162,33 +161,15 @@ def finalize_options(self): define_macros.append(('LARGE_OBJECTS', None)) if self.default_vars is None or self.default_vars: define_macros.append(('DEFAULT_VARS', None)) - wanted = self.escaping_funcs - supported = pg_version >= (9, 0) - if wanted or (wanted is None and supported): + if self.escaping_funcs is None or self.escaping_funcs: define_macros.append(('ESCAPING_FUNCS', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support the newer string escaping functions.") - wanted = self.pqlib_info - supported = pg_version >= (9, 1) - if wanted or (wanted is None and supported): + if self.pqlib_info is None or self.pqlib_info: define_macros.append(('PQLIB_INFO', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support PQLib info functions.") - wanted = self.ssl_info - supported = pg_version >= (9, 5) - if wanted or (wanted is None and supported): + if self.ssl_info is None or self.ssl_info: define_macros.append(('SSL_INFO', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support SSL info functions.") wanted = self.memory_size supported = pg_version >= (12, 0) - if wanted or (wanted is None and supported): + if (wanted is None and supported) or wanted: define_macros.append(('MEMORY_SIZE', None)) if not supported: warnings.warn( @@ -243,11 +224,7 @@ def finalize_options(self): "Operating System :: OS Independent", "Programming Language :: C", 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 5c043701..c43c2101 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -297,8 +297,7 @@ def testAllQueryMembers(self): members.remove('memsize') query_members = [ a for a in dir(query) - if not a.startswith('__') - and a != 'next'] # this is only needed in Python 2 + if not a.startswith('__')] self.assertEqual(members, query_members) def testMethodEndcopy(self): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fa3cc655..fd09c9d5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -5015,9 +5015,6 @@ def getLeaks(self, fut): 
gc.collect() objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] - if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python 3.5 issue 26811 - objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) def testLeaksWithClose(self): From db5374c893b3e9368e11caf2e1058d834ce1eca3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:01:01 +0000 Subject: [PATCH 121/194] Remove now unnecessary py3c shim --- docs/download/files.rst | 1 - pgconn.c | 58 ++++++++-------- pginternal.c | 56 +++++++--------- pglarge.c | 14 ++-- pgmodule.c | 75 ++++++++++----------- pgnotice.c | 6 +- pgquery.c | 27 ++++---- pgsource.c | 55 +++++++--------- py3c.h | 143 ---------------------------------------- 9 files changed, 134 insertions(+), 301 deletions(-) delete mode 100644 py3c.h diff --git a/docs/download/files.rst b/docs/download/files.rst index 4f4741fd..ec581bf0 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -12,7 +12,6 @@ pgquery.c the query object pgsource.c the source object pgtypes.h PostgreSQL type definitions -py3c.h Python 2/3 compatibility layer for the C extension pg.py the "classic" PyGreSQL module pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL diff --git a/pgconn.c b/pgconn.c index d39c9301..910f2212 100644 --- a/pgconn.c +++ b/pgconn.c @@ -26,7 +26,7 @@ conn_dealloc(connObject *self) static PyObject * conn_getattr(connObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* * Although we could check individually, there are only a few @@ -47,49 +47,49 @@ conn_getattr(connObject *self, PyObject *nameobj) char *r = PQhost(self->cnx); if (!r || r[0] == '/') /* Pg >= 9.6 can return a Unix socket path */ r = "localhost"; - return PyStr_FromString(r); + return PyUnicode_FromString(r); } /* postmaster port */ if (!strcmp(name, "port")) - return PyInt_FromLong(atol(PQport(self->cnx))); + return PyLong_FromLong(atol(PQport(self->cnx))); /* selected database */ if (!strcmp(name, "db")) - return PyStr_FromString(PQdb(self->cnx)); + return PyUnicode_FromString(PQdb(self->cnx)); /* selected options */ if (!strcmp(name, "options")) - return PyStr_FromString(PQoptions(self->cnx)); + return PyUnicode_FromString(PQoptions(self->cnx)); /* error (status) message */ if (!strcmp(name, "error")) - return PyStr_FromString(PQerrorMessage(self->cnx)); + return PyUnicode_FromString(PQerrorMessage(self->cnx)); /* connection status : 1 - OK, 0 - BAD */ if (!strcmp(name, "status")) - return PyInt_FromLong(PQstatus(self->cnx) == CONNECTION_OK ? 1 : 0); + return PyLong_FromLong(PQstatus(self->cnx) == CONNECTION_OK ? 
1 : 0); /* provided user name */ if (!strcmp(name, "user")) - return PyStr_FromString(PQuser(self->cnx)); + return PyUnicode_FromString(PQuser(self->cnx)); /* protocol version */ if (!strcmp(name, "protocol_version")) - return PyInt_FromLong(PQprotocolVersion(self->cnx)); + return PyLong_FromLong(PQprotocolVersion(self->cnx)); /* backend version */ if (!strcmp(name, "server_version")) - return PyInt_FromLong(PQserverVersion(self->cnx)); + return PyLong_FromLong(PQserverVersion(self->cnx)); /* descriptor number of connection socket */ if (!strcmp(name, "socket")) { - return PyInt_FromLong(PQsocket(self->cnx)); + return PyLong_FromLong(PQsocket(self->cnx)); } /* PID of backend process */ if (!strcmp(name, "backend_pid")) { - return PyInt_FromLong(PQbackendPID(self->cnx)); + return PyLong_FromLong(PQbackendPID(self->cnx)); } /* whether the connection uses SSL */ @@ -183,7 +183,7 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) char *ret = PQcmdTuples(result); if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyStr_FromString(ret); + PyObject *obj = PyUnicode_FromString(ret); PQclear(result); return obj; } @@ -193,7 +193,7 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) } /* for a single insert, return the oid */ PQclear(result); - return PyInt_FromLong((long) oid); + return PyLong_FromLong((long) oid); } case PGRES_COPY_OUT: /* no data will be received */ case PGRES_COPY_IN: @@ -325,7 +325,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) return NULL; } *s++ = str_obj; - *p = PyStr_AsString(str_obj); + *p = PyUnicode_AsUTF8(str_obj); } } @@ -614,7 +614,7 @@ conn_getline(connObject *self, PyObject *noargs) } /* for backward compatibility, convert terminating newline to zero byte */ if (*line) line[strlen(line) - 1] = '\0'; - str = PyStr_FromString(line); + str = PyUnicode_FromString(line); PQfreemem(line); return str; } @@ -947,9 +947,9 @@ conn_inserttable(connObject *self, PyObject *args) Py_DECREF(s); } } - else if (PyInt_Check(item) || PyLong_Check(item)) { + else if (PyLong_Check(item)) { PyObject* s = PyObject_Str(item); - const char* t = PyStr_AsString(s); + const char* t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { *bufpt++ = *t++; --bufsiz; @@ -958,7 +958,7 @@ conn_inserttable(connObject *self, PyObject *args) } else { PyObject* s = PyObject_Repr(item); - const char* t = PyStr_AsString(s); + const char* t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { switch (*t) { @@ -1036,7 +1036,7 @@ conn_inserttable(connObject *self, PyObject *args) } else { long ntuples = atol(PQcmdTuples(result)); PQclear(result); - return PyInt_FromLong(ntuples); + return PyLong_FromLong(ntuples); } } @@ -1052,7 +1052,7 @@ conn_transaction(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong(PQtransactionStatus(self->cnx)); + return PyLong_FromLong(PQtransactionStatus(self->cnx)); } /* Get parameter setting. 
*/ @@ -1079,7 +1079,7 @@ conn_parameter(connObject *self, PyObject *args) name = PQparameterStatus(self->cnx, name); if (name) - return PyStr_FromString(name); + return PyUnicode_FromString(name); /* unknown parameter, return None */ Py_INCREF(Py_None); @@ -1107,7 +1107,7 @@ conn_date_format(connObject *self, PyObject *noargs) self->date_format = fmt; /* cache the result */ } - return PyStr_FromString(fmt); + return PyUnicode_FromString(fmt); } #ifdef ESCAPING_FUNCS @@ -1450,7 +1450,7 @@ conn_cancel(connObject *self, PyObject *noargs) } /* request that the server abandon processing of the current command */ - return PyInt_FromLong((long) PQrequestCancel(self->cnx)); + return PyLong_FromLong((long) PQrequestCancel(self->cnx)); } /* Get connection socket. */ @@ -1465,7 +1465,7 @@ conn_fileno(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong((long) PQsocket(self->cnx)); + return PyLong_FromLong((long) PQsocket(self->cnx)); } /* Set external typecast callback function. */ @@ -1536,7 +1536,7 @@ conn_poll(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong(rc); + return PyLong_FromLong(rc); } /* Set notice receiver callback function. */ @@ -1632,7 +1632,7 @@ conn_get_notify(connObject *self, PyObject *noargs) else { PyObject *notify_result, *tmp; - if (!(tmp = PyStr_FromString(notify->relname))) { + if (!(tmp = PyUnicode_FromString(notify->relname))) { return NULL; } @@ -1642,7 +1642,7 @@ conn_get_notify(connObject *self, PyObject *noargs) PyTuple_SET_ITEM(notify_result, 0, tmp); - if (!(tmp = PyInt_FromLong(notify->be_pid))) { + if (!(tmp = PyLong_FromLong(notify->be_pid))) { Py_DECREF(notify_result); return NULL; } @@ -1650,7 +1650,7 @@ conn_get_notify(connObject *self, PyObject *noargs) PyTuple_SET_ITEM(notify_result, 1, tmp); /* extra exists even in old versions that did not support it */ - if (!(tmp = PyStr_FromString(notify->extra))) { + if (!(tmp = PyUnicode_FromString(notify->extra))) { Py_DECREF(notify_result); return NULL; } diff --git a/pginternal.c b/pginternal.c index 6dcad8bc..50181b0d 100644 --- a/pginternal.c +++ b/pginternal.c @@ -247,11 +247,10 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) break; default: /* PYGRES_TEXT */ -#if IS_PY3 obj = get_decoded_string(s, size, encoding); - if (!obj) /* cannot decode */ -#endif - obj = PyBytes_FromStringAndSize(s, size); + if (!obj) { /* cannot decode */ + obj = PyBytes_FromStringAndSize(s, size); + } } return obj; @@ -296,7 +295,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) *t++ = *s++; } *t = '\0'; - obj = PyInt_FromString(buf, NULL, 10); + obj = PyLong_FromString(buf, NULL, 10); break; case PYGRES_LONG: @@ -312,7 +311,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromStringAndSize(s, size); + tmp_obj = PyUnicode_FromStringAndSize(s, size); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -336,7 +335,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) obj = PyObject_CallFunction(decimal, "(s)", buf); } else { - tmp_obj = PyStr_FromString(buf); + tmp_obj = PyUnicode_FromString(buf); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); @@ -344,7 +343,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) break; case PYGRES_DECIMAL: - tmp_obj = PyStr_FromStringAndSize(s, size); + tmp_obj = PyUnicode_FromStringAndSize(s, size); obj = decimal ? 
PyObject_CallFunctionObjArgs( decimal, tmp_obj, NULL) : PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); @@ -353,7 +352,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -363,7 +362,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromStringAndSize(s, size); + obj = PyUnicode_FromStringAndSize(s, size); } return obj; @@ -381,15 +380,12 @@ cast_unsized_simple(char *s, int type) switch (type) { /* this must be the PyGreSQL internal type */ case PYGRES_INT: - obj = PyInt_FromString(s, NULL, 10); - break; - case PYGRES_LONG: obj = PyLong_FromString(s, NULL, 10); break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -416,7 +412,7 @@ cast_unsized_simple(char *s, int type) obj = PyObject_CallFunction(decimal, "(s)", s); } else { - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); } @@ -425,7 +421,7 @@ cast_unsized_simple(char *s, int type) case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -435,7 +431,7 @@ cast_unsized_simple(char *s, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromString(s); + obj = PyUnicode_FromString(s); } return obj; @@ -613,12 +609,11 @@ cast_array(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, type); } else { /* external casting of base type */ -#if IS_PY3 element = encoding == pg_encoding_ascii ? NULL : get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { PyObject *tmp = element; element = PyObject_CallFunctionObjArgs( @@ -768,12 +763,11 @@ cast_record(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, etype); } else { /* external casting of base type */ -#if IS_PY3 element = encoding == pg_encoding_ascii ? 
NULL : get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { if (len) { PyObject *ecast = PySequence_GetItem(cast, i); @@ -1065,17 +1059,15 @@ set_error_msg_and_state(PyObject *type, { PyObject *err_obj, *msg_obj, *sql_obj = NULL; -#if IS_PY3 if (encoding == -1) /* unknown */ msg_obj = PyUnicode_DecodeLocale(msg, NULL); else msg_obj = get_decoded_string(msg, (Py_ssize_t) strlen(msg), encoding); if (!msg_obj) /* cannot decode */ -#endif - msg_obj = PyBytes_FromString(msg); + msg_obj = PyBytes_FromString(msg); if (sqlstate) { - sql_obj = PyStr_FromStringAndSize(sqlstate, 5); + sql_obj = PyUnicode_FromStringAndSize(sqlstate, 5); } else { Py_INCREF(Py_None); sql_obj = Py_None; @@ -1139,7 +1131,7 @@ get_ssl_attributes(PGconn *cnx) { const char *val = PQsslAttribute(cnx, *s); if (val) { - PyObject * val_obj = PyStr_FromString(val); + PyObject * val_obj = PyUnicode_FromString(val); PyDict_SetItemString(attr_dict, *s, val_obj); Py_DECREF(val_obj); @@ -1280,7 +1272,7 @@ format_result(const PGresult *res) /* create the footer */ sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); /* return the result */ - result = PyStr_FromString(buffer); + result = PyUnicode_FromString(buffer); PyMem_Free(buffer); return result; } @@ -1293,7 +1285,7 @@ format_result(const PGresult *res) } } else - return PyStr_FromString("(nothing selected)"); + return PyUnicode_FromString("(nothing selected)"); } /* Internal function converting a Postgres datestyles to date formats. */ diff --git a/pglarge.c b/pglarge.c index c080d658..863e2ec9 100644 --- a/pglarge.c +++ b/pglarge.c @@ -31,7 +31,7 @@ large_str(largeObject *self) sprintf(str, self->lo_fd >= 0 ? "Opened large object, oid %ld" : "Closed large object, oid %ld", (long) self->lo_oid); - return PyStr_FromString(str); + return PyUnicode_FromString(str); } /* Check validity of large object. */ @@ -67,7 +67,7 @@ _check_lo_obj(largeObject *self, int level) static PyObject * large_getattr(largeObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* list postgreSQL large object fields */ @@ -85,7 +85,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyInt_FromLong((long) self->lo_oid); + return PyLong_FromLong((long) self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; @@ -93,7 +93,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* error (status) message */ if (!strcmp(name, "error")) - return PyStr_FromString(PQerrorMessage(self->pgcnx->cnx)); + return PyUnicode_FromString(PQerrorMessage(self->pgcnx->cnx)); /* seeks name in methods (fallback) */ return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -285,7 +285,7 @@ large_seek(largeObject *self, PyObject *args) } /* returns position */ - return PyInt_FromLong(ret); + return PyLong_FromLong(ret); } /* Get large object size. */ @@ -325,7 +325,7 @@ large_size(largeObject *self, PyObject *noargs) } /* returns size */ - return PyInt_FromLong(end); + return PyLong_FromLong(end); } /* Get large object cursor position. 
*/ @@ -350,7 +350,7 @@ large_tell(largeObject *self, PyObject *noargs) } /* returns size */ - return PyInt_FromLong(start); + return PyLong_FromLong(start); } /* Export large object as unix file. */ diff --git a/pgmodule.c b/pgmodule.c index bbb4b0db..6adc79c0 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -19,9 +19,6 @@ /* The type definitions from */ #include "pgtypes.h" -/* Macros for single-source Python 2/3 compatibility */ -#include "py3c.h" - static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, @@ -237,7 +234,7 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) pghost = PyBytes_AsString(pg_default_host); if ((pgport == -1) && (pg_default_port != Py_None)) - pgport = (int) PyInt_AsLong(pg_default_port); + pgport = (int) PyLong_AsLong(pg_default_port); if ((!pgopt) && (pg_default_opt != Py_None)) pgopt = PyBytes_AsString(pg_default_opt); @@ -488,7 +485,7 @@ static PyObject * pg_get_datestyle(PyObject *self, PyObject *noargs) { if (date_format) { - return PyStr_FromString(date_format_to_style(date_format)); + return PyUnicode_FromString(date_format_to_style(date_format)); } else { Py_INCREF(Py_None); return Py_None; @@ -507,7 +504,7 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) if (decimal_point) { s[0] = decimal_point; s[1] = '\0'; - ret = PyStr_FromString(s); + ret = PyUnicode_FromString(s); } else { Py_INCREF(Py_None); ret = Py_None; @@ -804,7 +801,7 @@ pg_set_defhost(PyObject *self, PyObject *args) old = pg_default_host; if (tmp) { - pg_default_host = PyStr_FromString(tmp); + pg_default_host = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -847,7 +844,7 @@ pg_set_defbase(PyObject *self, PyObject *args) old = pg_default_base; if (tmp) { - pg_default_base = PyStr_FromString(tmp); + pg_default_base = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -890,7 +887,7 @@ pg_setdefopt(PyObject *self, PyObject *args) old = pg_default_opt; if (tmp) { - pg_default_opt = PyStr_FromString(tmp); + pg_default_opt = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -934,7 +931,7 @@ pg_set_defuser(PyObject *self, PyObject *args) old = pg_default_user; if (tmp) { - pg_default_user = PyStr_FromString(tmp); + pg_default_user = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -962,7 +959,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) } if (tmp) { - pg_default_passwd = PyStr_FromString(tmp); + pg_default_passwd = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -1006,7 +1003,7 @@ pg_set_defport(PyObject *self, PyObject *args) old = pg_default_port; if (port != -1) { - pg_default_port = PyInt_FromLong(port); + pg_default_port = PyLong_FromLong(port); } else { Py_INCREF(Py_None); @@ -1250,7 +1247,9 @@ static struct PyModuleDef moduleDef = { }; /* Initialization function for the module */ -MODULE_INIT_FUNC(_pg) +PyMODINIT_FUNC PyInit__pg(void); + +PyMODINIT_FUNC PyInit__pg(void) { PyObject *mod, *dict, *s; @@ -1259,18 +1258,10 @@ MODULE_INIT_FUNC(_pg) mod = PyModule_Create(&moduleDef); /* Initialize here because some Windows platforms get confused otherwise */ -#if IS_PY3 connType.tp_base = noticeType.tp_base = queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; #ifdef LARGE_OBJECTS largeType.tp_base = &PyBaseObject_Type; -#endif -#else - connType.ob_type = noticeType.ob_type = - queryType.ob_type = sourceType.ob_type = &PyType_Type; -#ifdef LARGE_OBJECTS - largeType.ob_type = &PyType_Type; -#endif #endif if 
(PyType_Ready(&connType) @@ -1288,10 +1279,10 @@ MODULE_INIT_FUNC(_pg) dict = PyModule_GetDict(mod); /* Exceptions as defined by DB-API 2.0 */ - Error = PyErr_NewException("pg.Error", PyExc_StandardError, NULL); + Error = PyErr_NewException("pg.Error", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Error", Error); - Warning = PyErr_NewException("pg.Warning", PyExc_StandardError, NULL); + Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Warning", Warning); InterfaceError = PyErr_NewException( @@ -1339,39 +1330,39 @@ MODULE_INIT_FUNC(_pg) PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); /* Make the version available */ - s = PyStr_FromString(PyPgVersion); + s = PyUnicode_FromString(PyPgVersion); PyDict_SetItemString(dict, "version", s); PyDict_SetItemString(dict, "__version__", s); Py_DECREF(s); /* Result types for queries */ - PyDict_SetItemString(dict, "RESULT_EMPTY", PyInt_FromLong(RESULT_EMPTY)); - PyDict_SetItemString(dict, "RESULT_DML", PyInt_FromLong(RESULT_DML)); - PyDict_SetItemString(dict, "RESULT_DDL", PyInt_FromLong(RESULT_DDL)); - PyDict_SetItemString(dict, "RESULT_DQL", PyInt_FromLong(RESULT_DQL)); + PyDict_SetItemString(dict, "RESULT_EMPTY", PyLong_FromLong(RESULT_EMPTY)); + PyDict_SetItemString(dict, "RESULT_DML", PyLong_FromLong(RESULT_DML)); + PyDict_SetItemString(dict, "RESULT_DDL", PyLong_FromLong(RESULT_DDL)); + PyDict_SetItemString(dict, "RESULT_DQL", PyLong_FromLong(RESULT_DQL)); /* Transaction states */ - PyDict_SetItemString(dict, "TRANS_IDLE", PyInt_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict, "TRANS_ACTIVE", PyInt_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict, "TRANS_INTRANS", PyInt_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict, "TRANS_INERROR", PyInt_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyInt_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_IDLE", PyLong_FromLong(PQTRANS_IDLE)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", PyLong_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", PyLong_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", PyLong_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyLong_FromLong(PQTRANS_UNKNOWN)); /* Polling results */ - PyDict_SetItemString(dict, "POLLING_OK", PyInt_FromLong(PGRES_POLLING_OK)); - PyDict_SetItemString(dict, "POLLING_FAILED", PyInt_FromLong(PGRES_POLLING_FAILED)); - PyDict_SetItemString(dict, "POLLING_READING", PyInt_FromLong(PGRES_POLLING_READING)); - PyDict_SetItemString(dict, "POLLING_WRITING", PyInt_FromLong(PGRES_POLLING_WRITING)); + PyDict_SetItemString(dict, "POLLING_OK", PyLong_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", PyLong_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); #ifdef LARGE_OBJECTS /* Create mode for large objects */ - PyDict_SetItemString(dict, "INV_READ", PyInt_FromLong(INV_READ)); - PyDict_SetItemString(dict, "INV_WRITE", PyInt_FromLong(INV_WRITE)); + PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); + PyDict_SetItemString(dict, "INV_WRITE", PyLong_FromLong(INV_WRITE)); /* Position flags for lo_lseek */ - PyDict_SetItemString(dict, "SEEK_SET", PyInt_FromLong(SEEK_SET)); - PyDict_SetItemString(dict, "SEEK_CUR", PyInt_FromLong(SEEK_CUR)); - PyDict_SetItemString(dict, 
"SEEK_END", PyInt_FromLong(SEEK_END)); + PyDict_SetItemString(dict, "SEEK_SET", PyLong_FromLong(SEEK_SET)); + PyDict_SetItemString(dict, "SEEK_CUR", PyLong_FromLong(SEEK_CUR)); + PyDict_SetItemString(dict, "SEEK_END", PyLong_FromLong(SEEK_END)); #endif /* LARGE_OBJECTS */ #ifdef DEFAULT_VARS diff --git a/pgnotice.c b/pgnotice.c index ae6b2b68..e079283c 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -13,7 +13,7 @@ static PyObject * notice_getattr(noticeObject *self, PyObject *nameobj) { PGresult const *res = self->res; - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); int fieldcode; if (!res) { @@ -35,7 +35,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) /* full message */ if (!strcmp(name, "message")) { - return PyStr_FromString(PQresultErrorMessage(res)); + return PyUnicode_FromString(PQresultErrorMessage(res)); } /* other possible fields */ @@ -51,7 +51,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) if (fieldcode) { char *s = PQresultErrorField(res, fieldcode); if (s) { - return PyStr_FromString(s); + return PyUnicode_FromString(s); } else { Py_INCREF(Py_None); return Py_None; diff --git a/pgquery.c b/pgquery.c index 0d7ebc7d..0923eb66 100644 --- a/pgquery.c +++ b/pgquery.c @@ -168,7 +168,7 @@ _get_async_result(queryObject *self, int keep) { are additional statements following this one, so we return an empty string where query() would return None. */ Py_DECREF(result); - result = PyStr_FromString(""); + result = PyUnicode_FromString(""); } return result; } @@ -266,7 +266,7 @@ static char query_ntuples__doc__[] = static PyObject * query_ntuples(queryObject *self, PyObject *noargs) { - return PyInt_FromLong(self->max_row); + return PyLong_FromLong(self->max_row); } /* List field names from query result. */ @@ -285,7 +285,7 @@ query_listfields(queryObject *self, PyObject *noargs) if (fieldstuple) { for (i = 0; i < self->num_fields; ++i) { name = PQfname(self->result, i); - str = PyStr_FromString(name); + str = PyUnicode_FromString(name); PyTuple_SET_ITEM(fieldstuple, i, str); } } @@ -317,7 +317,7 @@ query_fieldname(queryObject *self, PyObject *args) /* gets fields name and builds object */ name = PQfname(self->result, i); - return PyStr_FromString(name); + return PyUnicode_FromString(name); } /* Get field number from name in last result. */ @@ -343,7 +343,7 @@ query_fieldnum(queryObject *self, PyObject *args) return NULL; } - return PyInt_FromLong(num); + return PyLong_FromLong(num); } /* Build a tuple with info for query field with given number. 
*/ @@ -353,10 +353,10 @@ _query_build_field_info(PGresult *res, int col_num) { info = PyTuple_New(4); if (info) { - PyTuple_SET_ITEM(info, 0, PyStr_FromString(PQfname(res, col_num))); - PyTuple_SET_ITEM(info, 1, PyInt_FromLong((long) PQftype(res, col_num))); - PyTuple_SET_ITEM(info, 2, PyInt_FromLong(PQfsize(res, col_num))); - PyTuple_SET_ITEM(info, 3, PyInt_FromLong(PQfmod(res, col_num))); + PyTuple_SET_ITEM(info, 0, PyUnicode_FromString(PQfname(res, col_num))); + PyTuple_SET_ITEM(info, 1, PyLong_FromLong((long) PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 2, PyLong_FromLong(PQfsize(res, col_num))); + PyTuple_SET_ITEM(info, 3, PyLong_FromLong(PQfmod(res, col_num))); } return info; } @@ -383,13 +383,13 @@ query_fieldinfo(queryObject *self, PyObject *args) /* gets field number */ if (PyBytes_Check(field)) { num = PQfnumber(self->result, PyBytes_AsString(field)); - } else if (PyStr_Check(field)) { + } else if (PyUnicode_Check(field)) { PyObject *tmp = get_encoded_string(field, self->encoding); if (!tmp) return NULL; num = PQfnumber(self->result, PyBytes_AsString(tmp)); Py_DECREF(tmp); - } else if (PyInt_Check(field)) { - num = (int) PyInt_AsLong(field); + } else if (PyLong_Check(field)) { + num = (int) PyLong_AsLong(field); } else { PyErr_SetString(PyExc_TypeError, "Field should be given as column number or name"); @@ -980,8 +980,7 @@ static PyTypeObject queryType = { PyObject_GenericGetAttr, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT - |Py_TPFLAGS_HAVE_ITER, /* tp_flags */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ query__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ diff --git a/pgsource.c b/pgsource.c index 053ad02f..7b081273 100644 --- a/pgsource.c +++ b/pgsource.c @@ -28,10 +28,10 @@ source_str(sourceObject *self) return format_result(self->result); case RESULT_DDL: case RESULT_DML: - return PyStr_FromString(PQcmdStatus(self->result)); + return PyUnicode_FromString(PQcmdStatus(self->result)); case RESULT_EMPTY: default: - return PyStr_FromString("(empty PostgreSQL source object)"); + return PyUnicode_FromString("(empty PostgreSQL source object)"); } } @@ -65,7 +65,7 @@ _check_source_obj(sourceObject *self, int level) static PyObject * source_getattr(sourceObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* pg connection object */ if (!strcmp(name, "pgcnx")) { @@ -79,19 +79,19 @@ source_getattr(sourceObject *self, PyObject *nameobj) /* arraysize */ if (!strcmp(name, "arraysize")) - return PyInt_FromLong(self->arraysize); + return PyLong_FromLong(self->arraysize); /* resulttype */ if (!strcmp(name, "resulttype")) - return PyInt_FromLong(self->result_type); + return PyLong_FromLong(self->result_type); /* ntuples */ if (!strcmp(name, "ntuples")) - return PyInt_FromLong(self->max_row); + return PyLong_FromLong(self->max_row); /* nfields */ if (!strcmp(name, "nfields")) - return PyInt_FromLong(self->num_fields); + return PyLong_FromLong(self->num_fields); /* seeks name in methods (fallback) */ return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -103,12 +103,12 @@ source_setattr(sourceObject *self, char *name, PyObject *v) { /* arraysize */ if (!strcmp(name, "arraysize")) { - if (!PyInt_Check(v)) { + if (!PyLong_Check(v)) { PyErr_SetString(PyExc_TypeError, "arraysize must be integer"); return -1; } - self->arraysize = PyInt_AsLong(v); + self->arraysize = PyLong_AsLong(v); return 0; } @@ -227,7 +227,7 @@ source_execute(sourceObject *self, PyObject *sql) 
self->result_type = RESULT_DDL; num_rows = -1; } - return PyInt_FromLong(num_rows); + return PyLong_FromLong(num_rows); } /* query failed */ @@ -272,7 +272,7 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyInt_FromLong((long) oid); + return PyLong_FromLong((long) oid); } /* Fetch rows from last result. */ @@ -287,9 +287,7 @@ source_fetch(sourceObject *self, PyObject *args) PyObject *res_list; int i, k; long size; -#if IS_PY3 int encoding; -#endif /* checks validity */ if (!_check_source_obj(self, CHECK_RESULT | CHECK_DQL | CHECK_CNX)) { @@ -313,9 +311,7 @@ source_fetch(sourceObject *self, PyObject *args) /* allocate list for result */ if (!(res_list = PyList_New(0))) return NULL; -#if IS_PY3 encoding = self->encoding; -#endif /* builds result */ for (i = 0, k = self->current_row; i < size; ++i, ++k) { @@ -336,15 +332,14 @@ source_fetch(sourceObject *self, PyObject *args) else { char *s = PQgetvalue(self->result, k, j); Py_ssize_t size = PQgetlength(self->result, k, j); -#if IS_PY3 if (PQfformat(self->result, j) == 0) { /* textual format */ str = get_decoded_string(s, size, encoding); if (!str) /* cannot decode */ str = PyBytes_FromStringAndSize(s, size); } - else -#endif - str = PyBytes_FromStringAndSize(s, size); + else { + str = PyBytes_FromStringAndSize(s, size); + } } PyTuple_SET_ITEM(rowtuple, j, str); } @@ -531,7 +526,7 @@ source_putdata(sourceObject *self, PyObject *buffer) tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { if (!errormsg) errormsg = PQerrorMessage(self->pgcnx->cnx); @@ -602,7 +597,7 @@ source_getdata(sourceObject *self, PyObject *args) tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->pgcnx->cnx)); @@ -634,11 +629,11 @@ _source_fieldindex(sourceObject *self, PyObject *param, const char *usage) return -1; /* gets field number */ - if (PyStr_Check(param)) { + if (PyUnicode_Check(param)) { num = PQfnumber(self->result, PyBytes_AsString(param)); } - else if (PyInt_Check(param)) { - num = (int) PyInt_AsLong(param); + else if (PyLong_Check(param)) { + num = (int) PyLong_AsLong(param); } else { PyErr_SetString(PyExc_TypeError, usage); @@ -667,15 +662,15 @@ _source_buildinfo(sourceObject *self, int num) } /* affects field information */ - PyTuple_SET_ITEM(result, 0, PyInt_FromLong(num)); + PyTuple_SET_ITEM(result, 0, PyLong_FromLong(num)); PyTuple_SET_ITEM(result, 1, - PyStr_FromString(PQfname(self->result, num))); + PyUnicode_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyInt_FromLong((long) PQftype(self->result, num))); + PyLong_FromLong((long) PQftype(self->result, num))); PyTuple_SET_ITEM(result, 3, - PyInt_FromLong(PQfsize(self->result, num))); + PyLong_FromLong(PQfsize(self->result, num))); PyTuple_SET_ITEM(result, 4, - PyInt_FromLong(PQfmod(self->result, num))); + PyLong_FromLong(PQfmod(self->result, num))); return result; } @@ -751,7 +746,7 @@ source_field(sourceObject *self, PyObject *desc) return NULL; } - return PyStr_FromString( + return PyUnicode_FromString( PQgetvalue(self->result, self->current_row, num)); } diff --git a/py3c.h b/py3c.h deleted file mode 100644 index c137b191..00000000 --- a/py3c.h +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2015, Red Hat, Inc. 
and/or its affiliates
- * Licensed under the MIT license; see py3c.h
- */
-
-#ifndef _PY3C_COMPAT_H_
-#define _PY3C_COMPAT_H_
-#define PY_SSIZE_T_CLEAN
-#include <Python.h>
-
-#if PY_MAJOR_VERSION >= 3
-
-/***** Python 3 *****/
-
-#define IS_PY3 1
-
-/* Strings */
-
-#define PyStr_Type PyUnicode_Type
-#define PyStr_Check PyUnicode_Check
-#define PyStr_CheckExact PyUnicode_CheckExact
-#define PyStr_FromString PyUnicode_FromString
-#define PyStr_FromStringAndSize PyUnicode_FromStringAndSize
-#define PyStr_FromFormat PyUnicode_FromFormat
-#define PyStr_FromFormatV PyUnicode_FromFormatV
-#define PyStr_AsString PyUnicode_AsUTF8
-#define PyStr_Concat PyUnicode_Concat
-#define PyStr_Format PyUnicode_Format
-#define PyStr_InternInPlace PyUnicode_InternInPlace
-#define PyStr_InternFromString PyUnicode_InternFromString
-#define PyStr_Decode PyUnicode_Decode
-
-#define PyStr_AsUTF8String PyUnicode_AsUTF8String // returns PyBytes
-#define PyStr_AsUTF8 PyUnicode_AsUTF8
-#define PyStr_AsUTF8AndSize PyUnicode_AsUTF8AndSize
-
-/* Ints */
-
-#define PyInt_Type PyLong_Type
-#define PyInt_Check PyLong_Check
-#define PyInt_CheckExact PyLong_CheckExact
-#define PyInt_FromString PyLong_FromString
-#define PyInt_FromLong PyLong_FromLong
-#define PyInt_FromSsize_t PyLong_FromSsize_t
-#define PyInt_FromSize_t PyLong_FromSize_t
-#define PyInt_AsLong PyLong_AsLong
-#define PyInt_AS_LONG PyLong_AS_LONG
-#define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask
-#define PyInt_AsSsize_t PyLong_AsSsize_t
-
-/* Module init */
-
-#define MODULE_INIT_FUNC(name) \
-    PyMODINIT_FUNC PyInit_ ## name(void); \
-    PyMODINIT_FUNC PyInit_ ## name(void)
-
-/* Other */
-
-#define Py_TPFLAGS_HAVE_ITER 0 // not needed in Python 3
-
-#define PyExc_StandardError PyExc_Exception // exists only in Python 2
-
-#else
-
-/***** Python 2 *****/
-
-#define IS_PY3 0
-
-/* Strings */
-
-#define PyStr_Type PyString_Type
-#define PyStr_Check PyString_Check
-#define PyStr_CheckExact PyString_CheckExact
-#define PyStr_FromString PyString_FromString
-#define PyStr_FromStringAndSize PyString_FromStringAndSize
-#define PyStr_FromFormat PyString_FromFormat
-#define PyStr_FromFormatV PyString_FromFormatV
-#define PyStr_AsString PyString_AsString
-#define PyStr_Format PyString_Format
-#define PyStr_InternInPlace PyString_InternInPlace
-#define PyStr_InternFromString PyString_InternFromString
-#define PyStr_Decode PyString_Decode
-
-static inline PyObject *PyStr_Concat(PyObject *left, PyObject *right) {
-    PyObject *str = left;
-    Py_INCREF(left); // reference to old left will be stolen
-    PyString_Concat(&str, right);
-    if (str) {
-        return str;
-    } else {
-        return NULL;
-    }
-}
-
-#define PyStr_AsUTF8String(str) (Py_INCREF(str), (str))
-#define PyStr_AsUTF8 PyString_AsString
-#define PyStr_AsUTF8AndSize(pystr, sizeptr) \
-    ((*sizeptr=PyString_Size(pystr)), PyString_AsString(pystr))
-
-#define PyBytes_Type PyString_Type
-#define PyBytes_Check PyString_Check
-#define PyBytes_CheckExact PyString_CheckExact
-#define PyBytes_FromString PyString_FromString
-#define PyBytes_FromStringAndSize PyString_FromStringAndSize
-#define PyBytes_FromFormat PyString_FromFormat
-#define PyBytes_FromFormatV PyString_FromFormatV
-#define PyBytes_Size PyString_Size
-#define PyBytes_GET_SIZE PyString_GET_SIZE
-#define PyBytes_AsString PyString_AsString
-#define PyBytes_AS_STRING PyString_AS_STRING
-#define PyBytes_AsStringAndSize PyString_AsStringAndSize
-#define PyBytes_Concat PyString_Concat
-#define PyBytes_ConcatAndDel PyString_ConcatAndDel
-#define _PyBytes_Resize _PyString_Resize
- -/* Floats */ - -#define PyFloat_FromString(str) PyFloat_FromString(str, NULL) - -/* Module init */ - -#define PyModuleDef_HEAD_INIT 0 - -typedef struct PyModuleDef { - int m_base; - const char* m_name; - const char* m_doc; - Py_ssize_t m_size; - PyMethodDef *m_methods; -} PyModuleDef; - -#define PyModule_Create(def) \ - Py_InitModule3((def)->m_name, (def)->m_methods, (def)->m_doc) - -#define MODULE_INIT_FUNC(name) \ - static PyObject *PyInit_ ## name(void); \ - void init ## name(void); \ - void init ## name(void) { PyInit_ ## name(); } \ - static PyObject *PyInit_ ## name(void) - - -#endif - -#endif From de8936f7ba56a2468aee8f37d1cb8c7dfa00fc55 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:04:50 +0000 Subject: [PATCH 122/194] Remove now unnecessary encoding comments --- tests/config.py | 1 - tests/test_classic.py | 1 - tests/test_classic_connection.py | 1 - tests/test_classic_dbwrapper.py | 1 - tests/test_classic_functions.py | 1 - tests/test_classic_largeobj.py | 1 - tests/test_classic_notification.py | 1 - tests/test_dbapi20.py | 1 - tests/test_dbapi20_copy.py | 1 - tests/test_tutorial.py | 1 - 10 files changed, 10 deletions(-) diff --git a/tests/config.py b/tests/config.py index a6dcd3a3..e6bf326c 100644 --- a/tests/config.py +++ b/tests/config.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- from os import environ diff --git a/tests/test_classic.py b/tests/test_classic.py index 727e4a86..3284d9ee 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- from __future__ import print_function diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c43c2101..8c7adc39 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fd09c9d5..0843710d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 653fbb87..db450ec8 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index fc0464d5..3271686c 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 39f607df..6f94cebd 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. 
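
Background on the encoding comments removed in this patch: under Python 2,
PEP 263 required a ``# -*- coding: utf-8 -*-`` declaration before any
non-ASCII byte could appear in a source file. Python 3 decodes source files
as UTF-8 by default (PEP 3120), so once Python 2 support is dropped the
cookie is redundant unless a file deliberately uses a different encoding.
A minimal sketch of the difference (illustrative only, not part of the
patch):

    # Valid as a Python 3 module without any coding declaration,
    # because source code is decoded as UTF-8 by default (PEP 3120).
    greeting = 'Hello, wörld!'  # non-ASCII literal, no cookie needed
    print(greeting)

    # Python 2 would refuse to compile this file with a SyntaxError
    # unless one of its first two lines declared the encoding.
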
diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py
index a03dca93..2d853f73 100755
--- a/tests/test_dbapi20.py
+++ b/tests/test_dbapi20.py
@@ -1,5 +1,4 @@
 #!/usr/bin/python
-# -*- coding: utf-8 -*-
 
 import gc
 import sys
diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py
index 47fc012a..d6fd1cfc 100644
--- a/tests/test_dbapi20_copy.py
+++ b/tests/test_dbapi20_copy.py
@@ -1,5 +1,4 @@
 #!/usr/bin/python
-# -*- coding: utf-8 -*-
 
 """Test the modern PyGreSQL interface.
 
diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py
index 6f968560..d9d1398b 100644
--- a/tests/test_tutorial.py
+++ b/tests/test_tutorial.py
@@ -1,5 +1,4 @@
 #!/usr/bin/python
-# -*- coding: utf-8 -*-
 
 from __future__ import print_function
 

From 0fd08bb5d9330d90a8e018f781b57c44606041ec Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Tue, 29 Aug 2023 17:07:18 +0000
Subject: [PATCH 123/194] Remove now unnecessary future statements

---
 pg.py | 2 --
 pgdb.py | 2 --
 tests/test_classic.py | 2 --
 tests/test_tutorial.py | 2 --
 4 files changed, 8 deletions(-)

diff --git a/pg.py b/pg.py
index 9d5a7e13..b8c4fa08 100644
--- a/pg.py
+++ b/pg.py
@@ -20,8 +20,6 @@
 For a DB-API 2 compliant interface use the newer pgdb module.
 """
 
-from __future__ import print_function, division
-
 try:
     from _pg import *
 except ImportError as e:
diff --git a/pgdb.py b/pgdb.py
index ccf848e9..2919caf3 100644
--- a/pgdb.py
+++ b/pgdb.py
@@ -64,8 +64,6 @@
     connection.close()  # close the connection
 """
 
-from __future__ import print_function, division
-
 try:
     from _pg import *
 except ImportError as e:
diff --git a/tests/test_classic.py b/tests/test_classic.py
index 3284d9ee..375bad3f 100755
--- a/tests/test_classic.py
+++ b/tests/test_classic.py
@@ -1,7 +1,5 @@
 #!/usr/bin/python
 
-from __future__ import print_function
-
 import unittest
 
 from functools import partial
diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py
index d9d1398b..0193165a 100644
--- a/tests/test_tutorial.py
+++ b/tests/test_tutorial.py
@@ -1,7 +1,5 @@
 #!/usr/bin/python
 
-from __future__ import print_function
-
 import unittest
 
 from pg import DB

From 60f9eb4a73c4f489e47b797bdfe5a8984a1cc985 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Tue, 29 Aug 2023 20:08:04 +0200
Subject: [PATCH 124/194] Simplify pg modules assuming modern Python

---
 pg.py | 245 ++++++++++---------------------------------
 pgdb.py | 246 ++++++++++----------------------------------
 2 files changed, 84 insertions(+), 407 deletions(-)

diff --git a/pg.py b/pg.py
index b8c4fa08..2edeb6e2 100644
--- a/pg.py
+++ b/pg.py
@@ -46,10 +46,9 @@
     else:
         libpq += 'so'
     if e:
-        # note: we could use "raise from e" here in Python 3
         raise ImportError(
             "Cannot import shared library for PyGreSQL,\n"
-            "probably because no %s is installed.\n%s" % (libpq, e))
+            "probably because no %s is installed.\n%s" % (libpq, e)) from e
 
 __version__ = version
 
@@ -85,165 +84,24 @@
 import warnings
 import weakref
 
-from datetime import date, time, datetime, timedelta, tzinfo
+from datetime import date, time, datetime, timedelta
 from decimal import Decimal
 from math import isnan, isinf
 from collections import namedtuple, OrderedDict
+from inspect import signature
 from operator import itemgetter
-from functools import partial
+from functools import lru_cache, partial
 from re import compile as regex
 from json import loads as jsondecode, dumps as jsonencode
 from uuid import UUID
 from typing import Dict, List, Union  # noqa: F401
 
-try:  # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
-    long
-except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: # noinspection PyCompatibility - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): - pass - - def __exit__(self, exctype, excinst, exctb): - pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prv, nxt, _arg, res = link - prv[1] = nxt - nxt[0] = prv - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif full: - oldroot = root - oldroot[2] = arg - oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # noqa F481 (keep reference) - root[2] = root[3] = None - del cache[oldarg] - cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - # Auxiliary classes and functions that are independent of a DB connection: -try: # noinspection PyUnresolvedReferences - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec - - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset +def get_args(func): + return list(signature(func).parameters) - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', @@ -259,14 +117,6 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - def _oid_key(table): 
"""Build oid key from a table name.""" return 'oid(%s)' % table @@ -285,7 +135,7 @@ class Hstore(dict): def _quote(cls, s): if s is None: return 'NULL' - if not isinstance(s, basestring): + if not isinstance(s, str): s = str(s) if not s: return '""' @@ -308,7 +158,7 @@ def __init__(self, obj, encode=None): def __str__(self): obj = self.obj - if isinstance(obj, basestring): + if isinstance(obj, str): return obj return self.encode(obj) @@ -330,8 +180,7 @@ class _SimpleTypes(dict): 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], 'num': ['numeric', Decimal], 'money': [], - 'text': ['bpchar', 'char', 'name', 'varchar', - bytes, unicode, basestring] + 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] } # type: Dict[str, List[Union[str, type]]] # noinspection PyMissingConstructor @@ -369,7 +218,7 @@ def _quote_if_unqualified(param, name): (could be a qualified name or just a name with a dot in it) and must be quoted manually by the caller. """ - if isinstance(name, basestring) and '.' not in name: + if isinstance(name, str) and '.' not in name: return 'quote_ident(%s)' % (param,) return param @@ -440,7 +289,7 @@ def __init__(self, db): @classmethod def _adapt_bool(cls, v): """Adapt a boolean parameter.""" - if isinstance(v, basestring): + if isinstance(v, str): if not v: return None v = v.lower() in cls._bool_true_values @@ -451,7 +300,7 @@ def _adapt_date(cls, v): """Adapt a date parameter.""" if not v: return None - if isinstance(v, basestring) and v.lower() in cls._date_literals: + if isinstance(v, str) and v.lower() in cls._date_literals: return Literal(v) return v @@ -472,7 +321,7 @@ def _adapt_json(self, v): """Adapt a json parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v if isinstance(v, Json): return str(v) @@ -482,7 +331,7 @@ def _adapt_hstore(self, v): """Adapt a hstore parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v if isinstance(v, Hstore): return str(v) @@ -494,7 +343,7 @@ def _adapt_uuid(self, v): """Adapt a UUID parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v return str(v) @@ -523,7 +372,7 @@ def _adapt_bool_array(cls, v): return '{%s}' % ','.join(adapt(v) for v in v) if v is None: return 'null' - if isinstance(v, basestring): + if isinstance(v, str): if not v: return 'null' v = v.lower() in cls._bool_true_values @@ -558,7 +407,7 @@ def _adapt_json_array(self, v): return '{%s}' % ','.join(adapt(v) for v in v) if not v: return 'null' - if not isinstance(v, basestring): + if not isinstance(v, str): v = self.db.encode_json(v) if self._re_array_quote.search(v): v = '"%s"' % self._re_array_escape.sub(r'\\\1', v) @@ -642,11 +491,11 @@ def guess_simple_type(cls, value): return _simple_type_dict[type(value)] except KeyError: pass - if isinstance(value, basestring): + if isinstance(value, (bytes, str)): return 'text' if isinstance(value, bool): return 'bool' - if isinstance(value, (int, long)): + if isinstance(value, int): return 'int' if isinstance(value, float): return 'float' @@ -695,12 +544,10 @@ def adapt_inline(self, value, nested=False): if isinstance(value, Literal): return value if isinstance(value, Bytea): - value = self.db.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') + value = self.db.escape_bytea(value).decode('ascii') elif isinstance(value, (datetime, date, time, timedelta)): value = str(value) - 
if isinstance(value, basestring): + if isinstance(value, (bytes, str)): value = self.db.escape_string(value) return "'%s'" % value if isinstance(value, bool): @@ -711,7 +558,7 @@ def adapt_inline(self, value, nested=False): if isnan(value): return "'NaN'" return value - if isinstance(value, (int, long, Decimal)): + if isinstance(value, (int, Decimal)): return value if isinstance(value, list): q = self.adapt_inline @@ -767,7 +614,7 @@ def format_query(self, command, values=None, types=None, inline=False): else: add = params.add if types: - if isinstance(types, basestring): + if isinstance(types, str): types = types.split() if (not isinstance(types, (list, tuple)) or len(types) != len(values)): @@ -884,12 +731,9 @@ def cast_timetz(value): else: tz = '+0000' fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) + value += _timezone_as_offset(tz) + fmt += '%z' + return datetime.strptime(value, fmt).timetz() def cast_timestamp(value, connection): @@ -944,12 +788,9 @@ def cast_timestamptz(value, connection): if len(value[0]) > 10: return datetime.max fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) + value.append(_timezone_as_offset(tz)) + fmt.append('%z') + return datetime.strptime(' '.join(value), ' '.join(fmt)) _re_interval_sql_standard = regex( @@ -1057,7 +898,7 @@ class Typecasts(dict): 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, 'float4': float, 'float8': float, 'numeric': cast_num, 'money': cast_money, @@ -1117,7 +958,7 @@ def get(self, typ, default=None): def set(self, typ, cast): """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] if cast is None: for t in typ: @@ -1138,7 +979,7 @@ def reset(self, typ=None): if typ is None: self.clear() else: - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] for t in typ: self.pop(t, None) @@ -1151,7 +992,7 @@ def get_default(cls, typ): @classmethod def set_default(cls, typ, cast): """Set a default typecast function for the given database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] defaults = cls.defaults if cast is None: @@ -1716,7 +1557,7 @@ def _do_debug(self, *args): """Print a debug message""" if self.debug: s = '\n'.join(str(arg) for arg in args) - if isinstance(self.debug, basestring): + if isinstance(self.debug, str): print(self.debug % s) elif hasattr(self.debug, 'write'): # noinspection PyCallingNonCallable @@ -1858,7 +1699,7 @@ def get_parameter(self, parameter): By passing the special name 'all' as the parameter, you can get a dict of all existing configuration parameters. 
""" - if isinstance(parameter, basestring): + if isinstance(parameter, str): parameter = [parameter] values = None elif isinstance(parameter, (list, tuple)): @@ -1875,7 +1716,7 @@ def get_parameter(self, parameter): params = {} if isinstance(values, dict) else [] for key in parameter: param = key.strip().lower() if isinstance( - key, basestring) else None + key, (bytes, str)) else None if not param: raise TypeError('Invalid parameter') if param == 'all': @@ -1923,7 +1764,7 @@ def set_parameter(self, parameter, value=None, local=False): have no effect if it is executed outside a transaction, since the transaction will end immediately. """ - if isinstance(parameter, basestring): + if isinstance(parameter, str): parameter = {parameter: value} elif isinstance(parameter, (list, tuple)): if isinstance(value, (list, tuple)): @@ -1935,7 +1776,7 @@ def set_parameter(self, parameter, value=None, local=False): value = set(value) if len(value) == 1: value = value.pop() - if not (value is None or isinstance(value, basestring)): + if not (value is None or isinstance(value, str)): raise ValueError( 'A single value must be specified' ' when parameter is a set') @@ -1953,7 +1794,7 @@ def set_parameter(self, parameter, value=None, local=False): params = {} for key, value in parameter.items(): param = key.strip().lower() if isinstance( - key, basestring) else None + key, str) else None if not param: raise TypeError('Invalid parameter') if param == 'all': @@ -2272,7 +2113,7 @@ def get(self, table, row, keyname=None): table = table[:-1].rstrip() attnames = self.get_attnames(table) qoid = _oid_key(table) if 'oid' in attnames else None - if keyname and isinstance(keyname, basestring): + if keyname and isinstance(keyname, str): keyname = (keyname,) if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: row['oid'] = row[qoid] @@ -2513,7 +2354,7 @@ def upsert(self, table, row=None, **kw): if n not in keyname and n not in generated: value = kw.get(n, n in row) if value: - if not isinstance(value, basestring): + if not isinstance(value, str): value = 'excluded.%s' % col(n) update.append('%s = %s' % (col(n), value)) if not values: @@ -2631,7 +2472,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): can be specified after the table name to explicitly indicate that descendant tables are included. 
""" - if isinstance(table, basestring): + if isinstance(table, str): only = {table: only} table = [table] elif isinstance(table, (list, tuple)): @@ -2764,7 +2605,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, keyname = self.pkey(table, True) except (KeyError, ProgrammingError): raise _prg_error('Table %s has no primary key' % table) - if isinstance(keyname, basestring): + if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): raise KeyError('The keyname must be a string, list or tuple') diff --git a/pgdb.py b/pgdb.py index 2919caf3..85767e3a 100644 --- a/pgdb.py +++ b/pgdb.py @@ -90,10 +90,9 @@ else: libpq += 'so' if e: - # note: we could use "raise from e" here in Python 3 raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) + "probably because no %s is installed.\n%s" % (libpq, e)) from e __version__ = version @@ -114,122 +113,20 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from datetime import date, time, datetime, timedelta, tzinfo +from datetime import date, time, datetime, timedelta from time import localtime from decimal import Decimal as StdDecimal from uuid import UUID as Uuid from math import isnan, isinf -try: # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable from collections import namedtuple -from functools import partial +from collections.abc import Iterable +from inspect import signature +from functools import lru_cache, partial from re import compile as regex from json import loads as jsondecode, dumps as jsonencode Decimal = StdDecimal -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: # noinspection PyCompatibility - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): - pass - - def __exit__(self, exctype, excinst, exctb): - pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prv, nxt, _arg, res = link - prv[1] = nxt - nxt[0] = prv - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif full: - oldroot = root - oldroot[2] = arg - oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # noqa F481 (keep reference) - root[2] = root[3] = None - del cache[oldarg] - 
cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - # *** Module Constants *** @@ -249,51 +146,9 @@ def wrapper(arg): # *** Internal Type Handling *** -try: # noinspection PyUnresolvedReferences - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec +def get_args(func): + return list(signature(func).parameters) - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset - - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', @@ -309,14 +164,6 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - def decimal_type(decimal_type=None): """Get or set global type to be used for decimal values. 
@@ -385,12 +232,9 @@ def cast_timetz(value): else: tz = '+0000' fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) + value += _timezone_as_offset(tz) + fmt += '%z' + return datetime.strptime(value, fmt).timetz() def cast_timestamp(value, connection): @@ -445,12 +289,9 @@ def cast_timestamptz(value, connection): if len(value[0]) > 10: return datetime.max fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) + value.append(_timezone_as_offset(tz)) + fmt.append('%z') + return datetime.strptime(' '.join(value), ' '.join(fmt)) _re_interval_sql_standard = regex( @@ -555,7 +396,7 @@ class Typecasts(dict): 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, 'float4': float, 'float8': float, 'numeric': Decimal, 'money': cast_money, @@ -611,7 +452,7 @@ def get(self, typ, default=None): def set(self, typ, cast): """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] if cast is None: for t in typ: @@ -634,7 +475,7 @@ def reset(self, typ=None): self.clear() self.update(defaults) else: - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] for t in typ: cast = defaults.get(t) @@ -967,11 +808,9 @@ def _quote(self, value): return 'NULL' if isinstance(value, (Hstore, Json)): value = str(value) - if isinstance(value, basestring): + if isinstance(value, (bytes, str)): if isinstance(value, Binary): - value = self._cnx.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') + value = self._cnx.escape_bytea(value).decode('ascii') else: value = self._cnx.escape_string(value) return "'%s'" % (value,) @@ -981,7 +820,7 @@ def _quote(self, value): if isnan(value): return "'NaN'" return value - if isinstance(value, (int, long, Decimal, Literal)): + if isinstance(value, (int, Decimal, Literal)): return value if isinstance(value, datetime): if value.tzinfo: @@ -1237,10 +1076,10 @@ def copy_from(self, stream, table, input_type = bytes type_name = 'byte strings' else: - input_type = basestring + input_type = (bytes, str) type_name = 'strings' - if isinstance(stream, basestring): + if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): raise ValueError("The input must be %s" % (type_name,)) if not binary_format: @@ -1291,7 +1130,7 @@ def chunks(): def chunks(): yield read() - if not table or not isinstance(table, basestring): + if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): raise ValueError("Must specify a table, not a query") @@ -1302,13 +1141,13 @@ def chunks(): options = [] params = [] if format is not None: - if not isinstance(format, basestring): + if not isinstance(format, str): raise TypeError("The format option must be be a string") if format not in ('text', 
'csv', 'binary'): raise ValueError("Invalid format") options.append('format %s' % (format,)) if sep is not None: - if not isinstance(sep, basestring): + if not isinstance(sep, str): raise TypeError("The sep option must be a string") if format == 'binary': raise ValueError( @@ -1319,12 +1158,12 @@ def chunks(): options.append('delimiter %s') params.append(sep) if null is not None: - if not isinstance(null, basestring): + if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') params.append(null) if columns: - if not isinstance(columns, basestring): + if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) @@ -1375,7 +1214,7 @@ def copy_to(self, stream, table, write = stream.write except AttributeError: raise TypeError("Need an output stream to copy to") - if not table or not isinstance(table, basestring): + if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): if columns: @@ -1388,13 +1227,13 @@ def copy_to(self, stream, table, options = [] params = [] if format is not None: - if not isinstance(format, basestring): + if not isinstance(format, str): raise TypeError("The format option must be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") options.append('format %s' % (format,)) if sep is not None: - if not isinstance(sep, basestring): + if not isinstance(sep, str): raise TypeError("The sep option must be a string") if binary_format: raise ValueError( @@ -1405,15 +1244,12 @@ def copy_to(self, stream, table, options.append('delimiter %s') params.append(sep) if null is not None: - if not isinstance(null, basestring): + if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') params.append(null) if decode is None: - if format == 'binary': - decode = False - else: - decode = str is unicode + decode = format != 'binary' else: if not isinstance(decode, (int, bool)): raise TypeError("The decode option must be a boolean") @@ -1421,7 +1257,7 @@ def copy_to(self, stream, table, raise ValueError( "The decode option is not allowed with binary format") if columns: - if not isinstance(columns, basestring): + if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) @@ -1730,12 +1566,12 @@ class Type(frozenset): """ def __new__(cls, values): - if isinstance(values, basestring): + if isinstance(values, str): values = values.split() return super(Type, cls).__new__(cls, values) def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self @@ -1743,7 +1579,7 @@ def __eq__(self, other): return super(Type, self).__eq__(other) def __ne__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self @@ -1755,13 +1591,13 @@ class ArrayType: """Type class for PostgreSQL array types.""" def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): return other.startswith('_') else: return isinstance(other, ArrayType) def __ne__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): return not other.startswith('_') else: return not isinstance(other, ArrayType) @@ -1774,7 +1610,7 @@ def __eq__(self, other): 
if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type == 'c' - elif isinstance(other, basestring): + elif isinstance(other, str): return other == 'record' else: return isinstance(other, RecordType) @@ -1783,7 +1619,7 @@ def __ne__(self, other): if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type != 'c' - elif isinstance(other, basestring): + elif isinstance(other, str): return other != 'record' else: return not isinstance(other, RecordType) @@ -1884,7 +1720,7 @@ class Hstore(dict): def _quote(cls, s): if s is None: return 'NULL' - if not isinstance(s, basestring): + if not isinstance(s, str): s = str(s) if not s: return '""' @@ -1908,7 +1744,7 @@ def __init__(self, obj, encode=None): def __str__(self): obj = self.obj - if isinstance(obj, basestring): + if isinstance(obj, str): return obj return self.encode(obj) From b1fcd3b61d053988aab4038d2723fb9aa8a747da Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 20:47:15 +0200 Subject: [PATCH 125/194] Simplify test modules assuming modern Python --- docs/contents/pgdb/cursor.rst | 2 +- tests/dbapi20.py | 28 ++---- tests/test_classic_connection.py | 143 +++++++++---------------------- tests/test_classic_dbwrapper.py | 72 +++++----------- tests/test_classic_functions.py | 18 +--- tests/test_dbapi20.py | 38 +++----- tests/test_dbapi20_copy.py | 17 ++-- 7 files changed, 87 insertions(+), 231 deletions(-) diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index 52d600e8..e1ed8b0f 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -295,7 +295,7 @@ specified, all of them will be copied. :param str null: the textual representation of the ``NULL`` value, can also be an empty string (the default is ``'\\N'``) :param bool decode: whether decoded strings shall be returned - for non-binary formats (the default is True in Python 3) + for non-binary formats (the default is ``True``) :param list column: an optional list of column names :returns: a generator if stream is set to ``None``, otherwise the cursor diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 2bb7e2b0..b793fbf2 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,4 +1,5 @@ #!/usr/bin/python + """Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -9,23 +10,6 @@ import unittest import time -try: # noinspection PyUnresolvedReferences - _BaseException = StandardError # noqa: F821 -except NameError: # Python >= 3.0 - _BaseException = Exception - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - - -def str2bytes(sval): - if str is not unicode and isinstance(sval, str): - # noinspection PyUnresolvedReferences - sval = sval.decode("latin1") - return sval.encode("latin1") # python 3 make unicode into bytes - class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -103,7 +87,7 @@ def tearDown(self): pass finally: con.close() - except _BaseException: + except Exception: pass def _connect(self): @@ -151,8 +135,8 @@ def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. 
sub = issubclass - self.assertTrue(sub(self.driver.Warning, _BaseException)) - self.assertTrue(sub(self.driver.Error, _BaseException)) + self.assertTrue(sub(self.driver.Warning, Exception)) + self.assertTrue(sub(self.driver.Error, Exception)) self.assertTrue(sub(self.driver.InterfaceError, self.driver.Error)) self.assertTrue(sub(self.driver.DatabaseError, self.driver.Error)) @@ -805,8 +789,8 @@ def test_Timestamp(self): self.assertEqual(str(t1), str(t2)) def test_Binary(self): - self.driver.Binary(str2bytes('Something')) - self.driver.Binary(str2bytes('')) + self.driver.Binary(b'Something') + self.driver.Binary(b'') def test_STRING(self): self.assertTrue(hasattr(self.driver, 'STRING'), diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 8c7adc39..0dedc5c4 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -15,31 +15,13 @@ import os from collections import namedtuple - -try: - # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable - +from collections.abc import Iterable from decimal import Decimal import pg # the module under test from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -unicode_strings = str is not bytes - windows = os.name == 'nt' # There is a known a bug in libpq under Windows which can cause @@ -462,10 +444,10 @@ def testGetresult(self): def testGetresultLong(self): q = "select 9876543210" - result = long(9876543210) - self.assertIsInstance(result, long) + result = 9876543210 + self.assertIsInstance(result, int) v = self.c.query(q).getresult()[0][0] - self.assertIsInstance(v, long) + self.assertIsInstance(v, int) self.assertEqual(v, result) def testGetresultDecimal(self): @@ -506,10 +488,10 @@ def testDictresult(self): def testDictresultLong(self): q = "select 9876543210 as longjohnsilver" - result = long(9876543210) - self.assertIsInstance(result, long) + result = 9876543210 + self.assertIsInstance(result, int) v = self.c.query(q).dictresult()[0]['longjohnsilver'] - self.assertIsInstance(v, long) + self.assertIsInstance(v, int) self.assertEqual(v, result) def testDictresultDecimal(self): @@ -839,7 +821,7 @@ def testMemSize(self): query = self.c.query q = query("select repeat('foo!', 8)") size = q.memsize() - self.assertIsInstance(size, long) + self.assertIsInstance(size, int) self.assertGreaterEqual(size, 32) self.assertLess(size, 8000) q = query("select repeat('foo!', 2000)") @@ -875,8 +857,6 @@ def testDictresulAscii(self): def testGetresultUtf8(self): result = u'Hello, wörld & мир!' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('utf8') # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] @@ -894,8 +874,6 @@ def testGetresultUtf8(self): def testDictresultUtf8(self): result = u'Hello, wörld & мир!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('utf8') try: v = self.c.query(q).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): @@ -915,8 +893,6 @@ def testGetresultLatin1(self): self.skipTest("database does not support latin1") result = u'Hello, wörld!' 
q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin1') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -932,8 +908,6 @@ def testDictresultLatin1(self): self.skipTest("database does not support latin1") result = u'Hello, wörld!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('latin1') v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -949,8 +923,6 @@ def testGetresultCyrillic(self): self.skipTest("database does not support cyrillic") result = u'Hello, мир!' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('cyrillic') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -966,8 +938,6 @@ def testDictresultCyrillic(self): self.skipTest("database does not support cyrillic") result = u'Hello, мир!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('cyrillic') v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -983,8 +953,6 @@ def testGetresultLatin9(self): self.skipTest("database does not support latin9") result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin9') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1000,8 +968,6 @@ def testDictresultLatin9(self): self.skipTest("database does not support latin9") result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' q = u"select '%s' as menu" % result - if not unicode_strings: - result = result.encode('latin9') v = self.c.query(q).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1138,20 +1104,14 @@ def testQueryWithUnicodeParamsLatin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) + self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) query('set client_encoding=iso_8859_1') r = query( "select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) + self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) @@ -1173,10 +1133,7 @@ def testQueryWithUnicodeParamsCyrillic(self): ('Hello', u'wörld')) r = query( "select $1||', '||$2||'!'", ('Hello', u'мир')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, мир!',)]) - else: - self.assertEqual(r, [(u'Hello, мир!'.encode('cyrillic'),)]) + self.assertEqual(r, [('Hello, мир!',)]) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", @@ -1337,7 +1294,7 @@ def testInt(self): self.assert_proper_cast(0, 'xid', int) def testLong(self): - self.assert_proper_cast(0, 'bigint', long) + self.assert_proper_cast(0, 'bigint', int) def testFloat(self): self.assert_proper_cast(0, 'float', float) @@ -1806,22 +1763,22 @@ def tearDown(self): self.c.close() data = [ - (-1, -1, 
long(-1), True, '1492-10-12', '08:30:00', + (-1, -1, -1, True, '1492-10-12', '08:30:00', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), - (0, 0, long(0), False, '1607-04-14', '09:00:00', + (0, 0, 0, False, '1607-04-14', '09:00:00', 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), - (1, 1, long(1), True, '1801-03-04', '03:45:00', + (1, 1, 1, True, '1801-03-04', '03:45:00', 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), - (2, 2, long(2), False, '1903-12-17', '11:22:00', + (2, 2, 2, False, '1903-12-17', '11:22:00', 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] @classmethod def db_len(cls, s, encoding): # noinspection PyUnresolvedReferences if cls.has_encoding: - s = s if isinstance(s, unicode) else s.decode(encoding) + s = s if isinstance(s, str) else s.decode(encoding) else: - s = s.encode(encoding) if isinstance(s, unicode) else s + s = s.encode(encoding) if isinstance(s, str) else s return len(s) def get_back(self, encoding='utf-8'): @@ -1835,7 +1792,7 @@ def get_back(self, encoding='utf-8'): if row[1] is not None: # integer self.assertIsInstance(row[1], int) if row[2] is not None: # bigint - self.assertIsInstance(row[2], long) + self.assertIsInstance(row[2], int) if row[3] is not None: # boolean self.assertIsInstance(row[3], bool) if row[4] is not None: # date @@ -2039,7 +1996,7 @@ def testInserttableWithOutOfRangeData(self): ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) def testInserttableMaxValues(self): - data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), + data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, '2999-12-31', '11:59:59', 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, "1", "1234", "1234", "1234" * 100)] @@ -2054,16 +2011,15 @@ def testInserttableByteValues(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") row_bytes = tuple( - s.encode('utf-8') if isinstance(s, unicode) else s + s.encode('utf-8') if isinstance(s, str) else s for s in row_unicode) data = [row_bytes] * 2 self.c.inserttable('test', data) - if unicode_strings: - data = [row_unicode] * 2 + data = [row_unicode] * 2 self.assertEqual(self.get_back(), data) def testInserttableUnicodeUtf8(self): @@ -2074,16 +2030,11 @@ def testInserttableUnicodeUtf8(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('utf-8') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back(), data) def testInserttableUnicodeLatin1(self): @@ -2095,22 +2046,17 @@ def testInserttableUnicodeLatin1(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) 
row_unicode = tuple( - s.replace(u'€', u'¥') if isinstance(s, unicode) else s + s.replace(u'€', u'¥') if isinstance(s, str) else s for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('latin1') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin1'), data) def testInserttableUnicodeLatin9(self): @@ -2123,16 +2069,11 @@ def testInserttableUnicodeLatin9(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('latin9') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin9'), data) def testInserttableNoEncoding(self): @@ -2140,7 +2081,7 @@ def testInserttableNoEncoding(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] @@ -2164,7 +2105,7 @@ def __repr__(self): return s s = '1\'2"3\b4\f5\n6\r7\t8\b9\\0' - s1 = s.encode('ascii') if unicode_strings else s.decode('ascii') + s1 = s.encode('ascii') s2 = S() data = [(t,) for t in (s, s1, s2)] self.c.inserttable('test', data, ['t']) @@ -2596,7 +2537,7 @@ def testSetDecimal(self): pg.set_decimal(decimal_class) self.assertNotIsInstance(r, decimal_class) self.assertIsInstance(r, int) - self.assertEqual(r, int(3425)) + self.assertEqual(r, 3425) def testGetBool(self): use_bool = pg.get_bool() @@ -2725,10 +2666,7 @@ def testSetByteaEscaped(self): self.assertEqual(r, b'data') def testSetRowFactorySize(self): - try: - from functools import lru_cache - except ImportError: # Python < 3.2 - lru_cache = None + from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): @@ -2742,12 +2680,11 @@ def testSetRowFactorySize(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - if lru_cache: - info = pg._row_factory.cache_info() - self.assertEqual(info.maxsize, maxsize) - self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + info = pg._row_factory.cache_info() + self.assertEqual(info.maxsize, maxsize) + self.assertEqual(info.hits + info.misses, 6) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): @@ -2783,13 +2720,13 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"das is'' käse".encode('utf-8')) r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s cheesy") r = f(r"It's bad to have a \ inside.") 
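# Editor's note (descriptive comment): the setUpClass of this test class
# (touched again in a later patch of this series) runs
# "set standard_conforming_strings=off", and with that setting the
# standalone escape_string() doubles backslashes as well as single
# quotes -- hence the doubled backslash expected in the next assertion.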
self.assertEqual(r, r"It''s bad to have a \\ inside.") @@ -2801,13 +2738,13 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s cheesy") r = f(b'O\x00ps\xff!') self.assertEqual(r, b'O\\\\000ps\\\\377!') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 0843710d..25c3c11d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -21,6 +21,7 @@ from collections import OrderedDict from decimal import Decimal from datetime import date, time, datetime, timedelta +from io import StringIO from uuid import UUID from time import strftime from operator import itemgetter @@ -29,21 +30,6 @@ debug = False # let DB wrapper print debugging output -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -if str is bytes: # noinspection PyCompatibility,PyUnresolvedReferences - from StringIO import StringIO -else: # Python >= 3.0 - from io import StringIO - windows = os.name == 'nt' # There is a known a bug in libpq under Windows which can cause @@ -523,13 +509,13 @@ def testEscapeLiteral(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"'plain'") r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"'that''s käse'".encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"'that''s käse'") self.assertEqual(f(r"It's fine to have a \ inside."), r" E'It''s fine to have a \\ inside.'") @@ -542,13 +528,13 @@ def testEscapeIdentifier(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'"plain"') r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'"plain"') r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u'"that\'s käse"'.encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'"that\'s käse"') self.assertEqual(f(r"It's fine to have a \ inside."), '"It\'s fine to have a \\ inside."') @@ -561,13 +547,13 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b"plain") r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"plain") r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"that''s käse".encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s käse") self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") @@ -580,13 +566,13 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x706c61696e') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, 
u'\\x706c61696e') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') r = f(u"das is' käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') @@ -623,7 +609,7 @@ def testDecodeJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -1783,8 +1769,7 @@ def testInsert(self): (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), dict(i2=42, i4=123456, i8=9876543210), - dict(i2=2 ** 15 - 1, - i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)), + dict(i2=2 ** 15 - 1, i4=2 ** 31 - 1, i8=2 ** 63 - 1), dict(d=None), (dict(d=''), dict(d=None)), dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), dict(f4=None, f8=None), dict(f4=0, f8=0), @@ -2519,9 +2504,9 @@ def testClear(self): r['a'] = r['f'] = r['n'] = 1 r['d'] = r['t'] = 'x' r['b'] = 't' - r['oid'] = long(1) + r['oid'] = 1 r = clear(table, r) - result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1)) + result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=1) self.assertEqual(r, result) def testClearWithQuotedNames(self): @@ -3455,7 +3440,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3472,7 +3457,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3525,7 +3510,7 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3542,7 +3527,7 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3578,8 +3563,7 @@ def testArray(self): data = dict( id=42, i2=[42, 1234, None, 0, -1], i4=[42, 123456789, None, 0, 1, -1], - i8=[long(42), long(123456789123456789), None, - long(0), long(1), long(-1)], + i8=[42, 123456789123456789, None, 0, 1, -1], d=[decimal(42), long_decimal, None, decimal(0), decimal(1), decimal(-1), -long_decimal], f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0, @@ -4053,10 +4037,7 @@ def testTimetz(self): timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): tz = '%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tzinfo = 
datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) d = time(15, 9, 26, tzinfo=tzinfo) q = "select $1::timetz" @@ -4108,10 +4089,7 @@ def testTimestamptz(self): timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): tz = '%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): @@ -4546,20 +4524,14 @@ def testAdaptQueryTypedWithHstore(self): value = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", [value], [pg.Hstore]) self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) def testAdaptQueryTypedWithUuid(self): @@ -4658,8 +4630,6 @@ def testAdaptQueryUntypedWithHstore(self): value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) def testAdaptQueryUntypedDict(self): @@ -4729,8 +4699,6 @@ def testAdaptQueryInlineListWithHstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), inline=True) - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - sql = sql[:8] + ','.join(sorted(sql[8:-9].split(','))) + sql[-9:] self.assertEqual( sql, "select 'one=>\"it''s fine\",two=>2'::hstore") self.assertEqual(params, []) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index db450ec8..282ec6df 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -18,16 +18,6 @@ from datetime import timedelta -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" @@ -123,8 +113,8 @@ def testDefBase(self): def testPqlibVersion(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() - self.assertIsInstance(v, long) - self.assertGreater(v, 90000) + self.assertIsInstance(v, int) + self.assertGreater(v, 100000) self.assertLess(v, 160000) @@ -881,7 +871,7 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - 
self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f("that's cheese") self.assertIsInstance(r, str) @@ -893,7 +883,7 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f("that's cheese") self.assertIsInstance(r, str) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 2d853f73..8505e518 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -4,7 +4,7 @@ import sys import unittest -from datetime import date, time, datetime, timedelta +from datetime import date, time, datetime, timedelta, timezone from uuid import UUID as Uuid import pgdb @@ -17,11 +17,6 @@ from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - class PgBitString: """Test object with a PostgreSQL representation as Bit String.""" @@ -492,7 +487,7 @@ def test_fetch_2_rows(self): self.assertIsInstance(row0[1], bytes) self.assertIsInstance(row0[2], bool) self.assertIsInstance(row0[3], int) - self.assertIsInstance(row0[4], long) + self.assertIsInstance(row0[4], int) self.assertIsInstance(row0[5], float) self.assertIsInstance(row0[6], Decimal) self.assertIsInstance(row0[7], Decimal) @@ -600,8 +595,8 @@ def test_datetime(self): "tz timetz, tsz timestamptz)" % table) for n in range(3): values = [dt.date(), dt.time(), dt, dt.time(), dt] - values[3] = values[3].replace(tzinfo=pgdb.timezone.utc) - values[4] = values[4].replace(tzinfo=pgdb.timezone.utc) + values[3] = values[3].replace(tzinfo=timezone.utc) + values[4] = values[4].replace(tzinfo=timezone.utc) if n == 0: # input as objects params = values if n == 1: # input as text @@ -609,7 +604,7 @@ def test_datetime(self): elif n == 2: # input using type helpers d = (dt.year, dt.month, dt.day) t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (pgdb.timezone.utc,) + z = (timezone.utc,) params = [pgdb.Date(*d), pgdb.Time(*t), pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), pgdb.Timestamp(*(d + t + z))] @@ -1000,8 +995,6 @@ def test_unicode_with_utf8(self): output4 = cur.fetchone()[0] finally: con.close() - if str is bytes: # Python < 3.0 - s = s.encode('utf8') self.assertIsInstance(output1, str) self.assertEqual(output1, s) self.assertIsInstance(output2, str) @@ -1033,8 +1026,6 @@ def test_unicode_with_latin1(self): output4 = cur.fetchone()[0] finally: con.close() - if str is bytes: # Python < 3.0 - s = s.encode('latin1') self.assertIsInstance(output1, str) self.assertEqual(output1, s) self.assertIsInstance(output2, str) @@ -1347,10 +1338,7 @@ def test_no_close(self): self.assertEqual(row, data) def test_set_row_factory_size(self): - try: - from functools import lru_cache - except ImportError: # Python < 3.2 - lru_cache = None + from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() @@ -1366,12 +1354,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - if lru_cache: - info = pgdb._row_factory.cache_info() - self.assertEqual(info.maxsize, maxsize) - self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + info = pgdb._row_factory.cache_info() + self.assertEqual(info.maxsize, maxsize) + 
self.assertEqual(info.hits + info.misses, 6) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): ids = set() @@ -1384,9 +1371,6 @@ def test_memory_leaks(self): gc.collect() objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] - if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python issue 26811 - objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) def test_cve_2018_1058(self): diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d6fd1cfc..d8661251 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,10 +11,7 @@ import unittest -try: # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable +from collections.abc import Iterable import pgdb # the module under test @@ -29,15 +26,13 @@ class InputStream: def __init__(self, data): - if isinstance(data, unicode): + if isinstance(data, str): data = data.encode('utf-8') self.data = data or b'' self.sizes = [] def __str__(self): - data = self.data - if str is unicode: # Python >= 3.0 - data = data.decode('utf-8') + data = self.data.decode('utf-8') return data def __len__(self): @@ -60,16 +55,14 @@ def __init__(self): self.sizes = [] def __str__(self): - data = self.data - if str is unicode: # Python >= 3.0 - data = data.decode('utf-8') + data = self.data.decode('utf-8') return data def __len__(self): return len(self.data) def write(self, data): - if isinstance(data, unicode): + if isinstance(data, str): data = data.encode('utf-8') self.data += data self.sizes.append(len(data)) From 1f66d19e49acc36e63a0170ff0161e3f7ab3b4ea Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 20:47:39 +0200 Subject: [PATCH 126/194] Mention new version in README file --- .bumpversion.cfg | 2 +- README.rst | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 3a654eda..89aec55e 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.5 +current_version = 6.0 commit = False tag = False diff --git a/README.rst b/README.rst index a6054363..150effb5 100644 --- a/README.rst +++ b/README.rst @@ -9,7 +9,13 @@ PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -Starting with version 5.0, PyGreSQL also supports Python 3. 
+ +The following Python versions are supported: + +* PyGreSQL 4.x and earlier: Python 2 only +* PyGreSQL 5.x: Python 2 and Python 3 +* PyGreSQL 6.x and newer: Python 3 only + Installation ------------ From fe83a9eb73dfe6288963615fadc79d0ab6d5e497 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 21:02:30 +0200 Subject: [PATCH 127/194] Simplify one more test module --- tests/test_dbapi20_copy.py | 77 ++++++++++---------------------------- 1 file changed, 20 insertions(+), 57 deletions(-) diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d8661251..540ccf1e 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -17,11 +17,6 @@ from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - class InputStream: @@ -248,26 +243,12 @@ def test_input_string_multiple_rows(self): self.check_table() self.check_rowcount() - if str is unicode: # Python >= 3.0 - - def test_input_bytes(self): - self.copy_from(b'42\tHello, world!') - self.assertEqual(self.table_data, [(42, 'Hello, world!')]) - self.truncate_table() - self.copy_from(self.data_text.encode('utf-8')) - self.check_table() - - else: # Python < 3.0 - - def test_input_unicode(self): - if not self.can_encode: - self.skipTest('database does not support utf8') - self.copy_from(u'43\tWürstel, Käse!') - self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')]) - self.truncate_table() - # noinspection PyUnresolvedReferences - self.copy_from(self.data_text.decode('utf-8')) - self.check_table() + def test_input_bytes(self): + self.copy_from(b'42\tHello, world!') + self.assertEqual(self.table_data, [(42, 'Hello, world!')]) + self.truncate_table() + self.copy_from(self.data_text.encode('utf-8')) + self.check_table() def test_input_iterable(self): self.copy_from(self.data_text.splitlines()) @@ -281,12 +262,10 @@ def test_input_iterable_with_newlines(self): self.copy_from('%s\n' % row for row in self.data_text.splitlines()) self.check_table() - if str is unicode: # Python >= 3.0 - - def test_input_iterable_bytes(self): - self.copy_from(row.encode('utf-8') - for row in self.data_text.splitlines()) - self.check_table() + def test_input_iterable_bytes(self): + self.copy_from(row.encode('utf-8') + for row in self.data_text.splitlines()) + self.check_table() def test_sep(self): stream = ('%d-%s' % row for row in self.data) @@ -437,28 +416,14 @@ def test_generator_with_schema_name(self): ret = self.cursor.copy_to(None, 'public.copytest') self.assertEqual(''.join(ret), self.data_text) - if str is unicode: # Python >= 3.0 - - def test_generator_bytes(self): - ret = self.copy_to(decode=False) - self.assertIsInstance(ret, Iterable) - rows = list(ret) - self.assertEqual(len(rows), 3) - rows = b''.join(rows) - self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode('utf-8')) - - else: # Python < 3.0 - - def test_generator_unicode(self): - ret = self.copy_to(decode=True) - self.assertIsInstance(ret, Iterable) - rows = list(ret) - self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, unicode) - # noinspection PyUnresolvedReferences - self.assertEqual(rows, self.data_text.decode('utf-8')) + def test_generator_bytes(self): + ret = self.copy_to(decode=False) + self.assertIsInstance(ret, Iterable) + rows = list(ret) + self.assertEqual(len(rows), 3) + rows = b''.join(rows) + self.assertIsInstance(rows, bytes) + 
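# Editor's note (descriptive comment): with decode=False, copy_to()
# yields the raw bytes chunks produced by COPY TO STDOUT, so the joined
# result is compared against the utf-8 encoded reference text here;
# the decode=True variant checked in test_decode below yields str.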
self.assertEqual(rows, self.data_text.encode('utf-8')) def test_rowcount_increment(self): ret = self.copy_to() @@ -470,7 +435,7 @@ def test_decode(self): ret_raw = b''.join(self.copy_to(decode=False)) ret_decoded = ''.join(self.copy_to(decode=True)) self.assertIsInstance(ret_raw, bytes) - self.assertIsInstance(ret_decoded, unicode) + self.assertIsInstance(ret_decoded, str) self.assertEqual(ret_decoded, ret_raw.decode('utf-8')) self.check_rowcount() @@ -556,9 +521,7 @@ def test_file(self): ret = self.copy_to(stream) self.assertIs(ret, self.cursor) self.assertEqual(str(stream), self.data_text) - data = self.data_text - if str is unicode: # Python >= 3.0 - data = data.encode('utf-8') + data = self.data_text.encode('utf-8') sizes = [len(row) + 1 for row in data.splitlines()] self.assertEqual(stream.sizes, sizes) self.check_rowcount() From 07da5d802bff07601d2c62b537b2a6214484285e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 21:53:58 +0200 Subject: [PATCH 128/194] Do not use wildcard imports --- pg.py | 35 +++++++++++++++++++++++++++++--- pgdb.py | 18 ++++++++++++++-- tests/test_classic.py | 5 ++++- tests/test_classic_connection.py | 1 - tests/test_dbapi20.py | 2 -- tests/test_dbapi20_copy.py | 2 +- 6 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pg.py b/pg.py index 2edeb6e2..99e4aa62 100644 --- a/pg.py +++ b/pg.py @@ -21,7 +21,7 @@ """ try: - from _pg import * + from _pg import version except ImportError as e: import os libpq = 'libpq.' @@ -35,10 +35,11 @@ for path in paths: with os.add_dll_directory(os.path.abspath(path)): try: - from _pg import * + from _pg import version except ImportError: pass else: + del version e = None break if paths: @@ -49,6 +50,34 @@ raise ImportError( "Cannot import shared library for PyGreSQL,\n" "probably because no %s is installed.\n%s" % (libpq, e)) from e +else: + del version + +# import objects from extension module +from _pg import ( + Error, Warning, + DataError, DatabaseError, + IntegrityError, InterfaceError, InternalError, + InvalidResultError, MultipleResultsError, + NoResultError, NotSupportedError, + OperationalError, ProgrammingError, + INV_READ, INV_WRITE, + POLLING_OK, POLLING_FAILED, POLLING_READING, POLLING_WRITING, + SEEK_CUR, SEEK_END, SEEK_SET, + TRANS_ACTIVE, TRANS_IDLE, TRANS_INERROR, + TRANS_INTRANS, TRANS_UNKNOWN, + cast_array, cast_hstore, cast_record, + connect, escape_bytea, escape_string, unescape_bytea, + get_array, get_bool, get_bytea_escaped, + get_datestyle, get_decimal, get_decimal_point, + get_defbase, get_defhost, get_defopt, get_defport, get_defuser, + get_jsondecode, get_pqlib_version, + set_array, set_bool, set_bytea_escaped, + set_datestyle, set_decimal, set_decimal_point, + set_defbase, set_defhost, set_defopt, + set_defpasswd, set_defport, set_defuser, + set_jsondecode, set_query_helpers, + version) __version__ = version @@ -72,7 +101,7 @@ 'get_array', 'get_bool', 'get_bytea_escaped', 'get_datestyle', 'get_decimal', 'get_decimal_point', 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_typecast', + 'get_jsondecode', 'get_pqlib_version', 'get_typecast', 'set_array', 'set_bool', 'set_bytea_escaped', 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', diff --git a/pgdb.py b/pgdb.py index 85767e3a..4de78e15 100644 --- a/pgdb.py +++ b/pgdb.py @@ -65,7 +65,7 @@ """ try: - from _pg import * + from _pg import version except ImportError as e: import os libpq = 'libpq.' 
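[Editor's note] The hunk above replaces the former wildcard import with a
minimal probe import. A sketch of the same pattern, assuming a hypothetical
C extension ``_ext`` (the module name and error message are illustrative
only, not PyGreSQL API)::

    try:
        from _ext import version  # cheap probe: does the shared library load?
    except ImportError as e:
        # On Windows, the surrounding code retries this probe after calling
        # os.add_dll_directory() on the likely libpq locations.
        raise ImportError("shared library for _ext could not be loaded") from e
    else:
        del version  # all required names are imported explicitly afterwards

Probing with a single known name keeps the module namespace clean, and the
explicit import list that follows documents exactly which objects the
wrapper module relies on.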
@@ -79,10 +79,11 @@ for path in paths: with os.add_dll_directory(os.path.abspath(path)): try: - from _pg import * + from _pg import version except ImportError: pass else: + del version e = None break if paths: @@ -93,6 +94,19 @@ raise ImportError( "Cannot import shared library for PyGreSQL,\n" "probably because no %s is installed.\n%s" % (libpq, e)) from e +else: + del version + +# import objects from extension module +from _pg import ( + Error, Warning, + DataError, DatabaseError, + IntegrityError, InterfaceError, InternalError, + NotSupportedError, OperationalError, ProgrammingError, + cast_array, cast_hstore, cast_record, + RESULT_DQL, + connect, unescape_bytea, + version) __version__ = version diff --git a/tests/test_classic.py b/tests/test_classic.py index 375bad3f..799cb6c7 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -6,7 +6,10 @@ from time import sleep from threading import Thread -from pg import * +from pg import ( + DB, NotificationHandler, + Error, DatabaseError, IntegrityError, + NotSupportedError, ProgrammingError) from .config import dbname, dbhost, dbport, dbuser, dbpasswd diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0dedc5c4..068dd792 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2666,7 +2666,6 @@ def testSetByteaEscaped(self): self.assertEqual(r, b'data') def testSetRowFactorySize(self): - from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 8505e518..01a89247 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,7 +1,6 @@ #!/usr/bin/python import gc -import sys import unittest from datetime import date, time, datetime, timedelta, timezone @@ -1338,7 +1337,6 @@ def test_no_close(self): self.assertEqual(row, data) def test_set_row_factory_size(self): - from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 540ccf1e..769065ab 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -264,7 +264,7 @@ def test_input_iterable_with_newlines(self): def test_input_iterable_bytes(self): self.copy_from(row.encode('utf-8') - for row in self.data_text.splitlines()) + for row in self.data_text.splitlines()) self.check_table() def test_sep(self): From 3123f331e42c6408f11c1cdb49936b168ebaf847 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 22:31:36 +0200 Subject: [PATCH 129/194] Use some default parameters for testing --- tests/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/config.py b/tests/config.py index e6bf326c..6e2ebd3c 100644 --- a/tests/config.py +++ b/tests/config.py @@ -16,9 +16,9 @@ get = environ.get -dbname = get('PYGRESQL_DB', get('PGDATABASE')) -dbhost = get('PYGRESQL_HOST', get('PGHOST')) -dbport = get('PYGRESQL_PORT', get('PGPORT')) +dbname = get('PYGRESQL_DB', get('PGDATABASE', 'test')) +dbhost = get('PYGRESQL_HOST', get('PGHOST', 'localhost')) +dbport = get('PYGRESQL_PORT', get('PGPORT', 5432)) dbuser = get('PYGRESQL_USER', get('PGUSER')) dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) From da5bf3eafbfa72c70351b3b8319d35cf97dd05fc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 23:31:02 +0200 Subject: [PATCH 
130/194] Remove most compiler options These options only complicated things, and nobody used them anyway. Only the memory size option was kept, because the function it guards is not available in PostgreSQL 10, which is still supported. --- docs/contents/install.rst | 26 ++++++-------------- docs/contents/pg/connection.rst | 37 +++++++++++------------------ docs/contents/pg/module.rst | 28 ++++++++++------------ pgconn.c | 26 -------------------- pginternal.c | 4 ---- pgmodule.c | 27 --------------------- pgquery.c | 11 ++++----- setup.py | 42 +++------------------------------ tox.ini | 2 +- 9 files changed, 43 insertions(+), 160 deletions(-) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index d1926881..fd4f99b5 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -91,24 +91,24 @@ Now you should be ready to use PyGreSQL. You can also run the build step separately if you want to create a distribution to be installed on a different system or explicitly enable or disable certain -features. For instance, in order to build PyGreSQL without support for the SSL -info functions, run:: - python setup.py build_ext --no-ssl-info +features. For instance, in order to build PyGreSQL without support for the +memory size functions, run:: + python setup.py build_ext --no-memory-size By default, PyGreSQL is compiled with support for all features available in the installed PostgreSQL version, and you will get warnings for the features that are not supported in this version. You can also explicitly require a feature in order to get an error if it is not available, for instance:: - python setup.py build_ext --ssl-info + python setup.py build_ext --memory-size You can find out all possible build options with:: python setup.py build_ext --help Alternatively, you can also use the corresponding C preprocessor macros like -``SSL_INFO`` directly (see the next section). +``MEMORY_SIZE`` directly (see the next section). Note that if you build PyGreSQL with support for newer features that are not available in the PQLib installed on the runtime system, you may get an error @@ -154,13 +154,7 @@ Stand-Alone Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DPQLIB_INFO support PQLib information - -DSSL_INFO support SSL information - -DMEMORY_SIZE support memory size function + -DMEMORY_SIZE = support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. @@ -202,13 +196,7 @@ Built-in to Python interpreter Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DPQLIB_INFO support PQLib information - -DSSL_INFO support SSL information - -DMEMORY_SIZE support memory size function + -DMEMORY_SIZE = support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index d1c95213..7fd44cca 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -13,17 +13,8 @@ significant parameters in function calls. Some methods give direct access to the connection socket. 
*Do not use them unless you really know what you are doing.* - If you prefer disabling them, - do not set the ``direct_access`` option in the Python setup file. - These methods are specified by the tag [DA]. - -.. note:: - - Some other methods give access to large objects - (refer to PostgreSQL user manual for more information about these). - If you want to forbid access to these from the module, - set the ``large_objects`` option in the Python setup file. - These methods are specified by the tag [LO]. + Some other methods give access to large objects. + Refer to the PostgreSQL user manual for more information about these. query -- execute a SQL command string ------------------------------------- @@ -605,8 +596,8 @@ attributes: .. versionadded:: 4.1 -putline -- write a line to the server socket [DA] -------------------------------------------------- +putline -- write a line to the server socket +-------------------------------------------- .. method:: Connection.putline(line) @@ -618,8 +609,8 @@ putline -- write a line to the server socket [DA] This method allows to directly write a string to the server socket. -getline -- get a line from server socket [DA] ---------------------------------------------- +getline -- get a line from server socket +---------------------------------------- .. method:: Connection.getline() @@ -633,8 +624,8 @@ getline -- get a line from server socket [DA] This method allows to directly read a string from the server socket. -endcopy -- synchronize client and server [DA] ---------------------------------------------- +endcopy -- synchronize client and server +---------------------------------------- .. method:: Connection.endcopy() @@ -647,8 +638,8 @@ endcopy -- synchronize client and server [DA] The use of direct access methods may desynchronize client and server. This method ensure that client and server will be synchronized. -locreate -- create a large object in the database [LO] ------------------------------------------------------- +locreate -- create a large object in the database +------------------------------------------------- .. method:: Connection.locreate(mode) @@ -665,8 +656,8 @@ by OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, and :const:`INV_WRITE`). Please refer to PostgreSQL user manual for a description of the mode values. -getlo -- build a large object from given oid [LO] -------------------------------------------------- +getlo -- build a large object from given oid +-------------------------------------------- .. method:: Connection.getlo(oid) @@ -681,8 +672,8 @@ getlo -- build a large object from given oid [LO] This method allows reusing a previously created large object through the :class:`LargeObject` interface, provided the user has its OID. -loimport -- import a file to a large object [LO] ------------------------------------------------- +loimport -- import a file to a large object +------------------------------------------- .. method:: Connection.loimport(name) diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index b122808b..9faa3754 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -10,9 +10,7 @@ the environment variables used by PostgreSQL. These "default variables" were designed to allow you to handle general connection parameters without heavy code in your programs. You can prompt the user for a value, put it in the default variable, and forget it, without -having to modify your environment. 
The support for default variables can be -disabled by not setting the ``default_vars`` option in the Python setup file. -Methods relative to this are specified by the tag [DV]. +having to modify your environment. All variables are set to ``None`` at module initialization, specifying that standard environment variables should be used. @@ -87,8 +85,8 @@ For example, version 9.1.2 will be returned as 90102. .. versionadded:: 5.2 (needs PostgreSQL >= 9.1) -get/set_defhost -- default server host [DV] -------------------------------------------- +get/set_defhost -- default server host +-------------------------------------- .. function:: get_defhost(host) @@ -117,8 +115,8 @@ If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defport -- default server port [DV] -------------------------------------------- +get/set_defport -- default server port +-------------------------------------- .. function:: get_defport() @@ -145,8 +143,8 @@ This methods sets the default port value for new connections. If -1 is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default port. -get/set_defopt -- default connection options [DV] --------------------------------------------------- +get/set_defopt -- default connection options +--------------------------------------------- .. function:: get_defopt() @@ -174,8 +172,8 @@ This methods sets the default connection options value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default options. -get/set_defbase -- default database name [DV] ---------------------------------------------- +get/set_defbase -- default database name +---------------------------------------- .. function:: get_defbase() @@ -203,8 +201,8 @@ This method sets the default database name value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defuser -- default database user [DV] ---------------------------------------------- +get/set_defuser -- default database user +---------------------------------------- .. function:: get_defuser() @@ -232,8 +230,8 @@ This method sets the default database user name for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defpasswd -- default database password [DV] ---------------------------------------------------- +get/set_defpasswd -- default database password +---------------------------------------------- .. 
function:: get_defpasswd() diff --git a/pgconn.c b/pgconn.c index 910f2212..c67e74dc 100644 --- a/pgconn.c +++ b/pgconn.c @@ -94,27 +94,17 @@ conn_getattr(connObject *self, PyObject *nameobj) /* whether the connection uses SSL */ if (!strcmp(name, "ssl_in_use")) { -#ifdef SSL_INFO if (PQsslInUse(self->cnx)) { Py_INCREF(Py_True); return Py_True; } else { Py_INCREF(Py_False); return Py_False; } -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif } /* SSL attributes */ if (!strcmp(name, "ssl_attributes")) { -#ifdef SSL_INFO return get_ssl_attributes(self->cnx); -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif } return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -540,7 +530,6 @@ conn_describe_prepared(connObject *self, PyObject *args) return NULL; /* error */ } -#ifdef DIRECT_ACCESS static char conn_putline__doc__[] = "putline(line) -- send a line directly to the backend"; @@ -697,7 +686,6 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) return PyBool_FromLong((long)rc); } -#endif /* DIRECT_ACCESS */ /* Insert table */ @@ -1110,8 +1098,6 @@ conn_date_format(connObject *self, PyObject *noargs) return PyUnicode_FromString(fmt); } -#ifdef ESCAPING_FUNCS - /* Escape literal */ static char conn_escape_literal__doc__[] = "escape_literal(str) -- escape a literal constant for use within SQL"; @@ -1202,8 +1188,6 @@ conn_escape_identifier(connObject *self, PyObject *string) return to_obj; } -#endif /* ESCAPING_FUNCS */ - /* Escape string */ static char conn_escape_string__doc__[] = "escape_string(str) -- escape a string for use within SQL"; @@ -1299,8 +1283,6 @@ conn_escape_bytea(connObject *self, PyObject *data) return to_obj; } -#ifdef LARGE_OBJECTS - /* Constructor for large objects (internal use only) */ static largeObject * large_new(connObject *pgcnx, Oid oid) @@ -1415,8 +1397,6 @@ conn_loimport(connObject *self, PyObject *args) return (PyObject *) large_new(self, lo_oid); } -#endif /* LARGE_OBJECTS */ - /* Reset connection. 
*/ static char conn_reset__doc__[] = "reset() -- reset connection with current parameters\n\n" @@ -1724,18 +1704,15 @@ static struct PyMethodDef conn_methods[] = { {"date_format", (PyCFunction) conn_date_format, METH_NOARGS, conn_date_format__doc__}, -#ifdef ESCAPING_FUNCS {"escape_literal", (PyCFunction) conn_escape_literal, METH_O, conn_escape_literal__doc__}, {"escape_identifier", (PyCFunction) conn_escape_identifier, METH_O, conn_escape_identifier__doc__}, -#endif /* ESCAPING_FUNCS */ {"escape_string", (PyCFunction) conn_escape_string, METH_O, conn_escape_string__doc__}, {"escape_bytea", (PyCFunction) conn_escape_bytea, METH_O, conn_escape_bytea__doc__}, -#ifdef DIRECT_ACCESS {"putline", (PyCFunction) conn_putline, METH_VARARGS, conn_putline__doc__}, {"getline", (PyCFunction) conn_getline, @@ -1746,16 +1723,13 @@ static struct PyMethodDef conn_methods[] = { METH_VARARGS, conn_set_non_blocking__doc__}, {"is_non_blocking", (PyCFunction) conn_is_non_blocking, METH_NOARGS, conn_is_non_blocking__doc__}, -#endif /* DIRECT_ACCESS */ -#ifdef LARGE_OBJECTS {"locreate", (PyCFunction) conn_locreate, METH_VARARGS, conn_locreate__doc__}, {"getlo", (PyCFunction) conn_getlo, METH_VARARGS, conn_getlo__doc__}, {"loimport", (PyCFunction) conn_loimport, METH_VARARGS, conn_loimport__doc__}, -#endif /* LARGE_OBJECTS */ {NULL, NULL} /* sentinel */ }; diff --git a/pginternal.c b/pginternal.c index 50181b0d..61446f41 100644 --- a/pginternal.c +++ b/pginternal.c @@ -1115,8 +1115,6 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) set_error_msg_and_state(type, msg, encoding, sqlstate); } -#ifdef SSL_INFO - /* Get SSL attributes and values as a dictionary. */ static PyObject * get_ssl_attributes(PGconn *cnx) { @@ -1144,8 +1142,6 @@ get_ssl_attributes(PGconn *cnx) { return attr_dict; } -#endif /* SSL_INFO */ - /* Format result (mostly useful for debugging). Note: This is similar to the Postgres function PQprint(). 
PQprint() is not used because handing over a stream from Python to diff --git a/pgmodule.c b/pgmodule.c index 6adc79c0..f1335263 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -59,14 +59,12 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); /* MODULE GLOBAL VARIABLES */ -#ifdef DEFAULT_VARS static PyObject *pg_default_host; /* default database host */ static PyObject *pg_default_base; /* default database name */ static PyObject *pg_default_opt; /* default connection options */ static PyObject *pg_default_port; /* default connection port */ static PyObject *pg_default_user; /* default username */ static PyObject *pg_default_passwd; /* default password */ -#endif /* DEFAULT_VARS */ static PyObject *decimal = NULL, /* decimal type */ *dictiter = NULL, /* function for getting dict results */ @@ -160,7 +158,6 @@ typedef struct } queryObject; #define is_queryObject(v) (PyType(v) == &queryType) -#ifdef LARGE_OBJECTS typedef struct { PyObject_HEAD @@ -169,7 +166,6 @@ typedef struct int lo_fd; /* large object fd */ } largeObject; #define is_largeObject(v) (PyType(v) == &largeType) -#endif /* LARGE_OBJECTS */ /* Internal functions */ #include "pginternal.c" @@ -187,9 +183,7 @@ typedef struct #include "pgnotice.c" /* Large objects */ -#ifdef LARGE_OBJECTS #include "pglarge.c" -#endif /* MODULE FUNCTIONS */ @@ -228,7 +222,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return NULL; } -#ifdef DEFAULT_VARS /* handles defaults variables (for uninitialised vars) */ if ((!pghost) && (pg_default_host != Py_None)) pghost = PyBytes_AsString(pg_default_host); @@ -247,7 +240,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) if ((!pgpasswd) && (pg_default_passwd != Py_None)) pgpasswd = PyBytes_AsString(pg_default_passwd); -#endif /* DEFAULT_VARS */ if (!(conn_obj = PyObject_New(connObject, &connType))) { set_error_msg(InternalError, "Can't create new connection object"); @@ -309,8 +301,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return (PyObject *) conn_obj; } -#ifdef PQLIB_INFO - /* Get version of libpq that is being used */ static char pg_get_pqlib_version__doc__[] = "get_pqlib_version() -- get the version of libpq that is being used"; @@ -320,8 +310,6 @@ pg_get_pqlib_version(PyObject *self, PyObject *noargs) { return PyLong_FromLong(PQlibVersion()); } -#endif /* PQLIB_INFO */ - /* Escape string */ static char pg_escape_string__doc__[] = "escape_string(string) -- escape a string for use within SQL"; @@ -766,8 +754,6 @@ pg_set_jsondecode(PyObject *self, PyObject *func) return ret; } -#ifdef DEFAULT_VARS - /* Get default host. */ static char pg_get_defhost__doc__[] = "get_defhost() -- return default database host"; @@ -1012,7 +998,6 @@ pg_set_defport(PyObject *self, PyObject *args) return old; } -#endif /* DEFAULT_VARS */ /* Cast a string with a text representation of an array to a list. 
*/ static char pg_cast_array__doc__[] = @@ -1216,7 +1201,6 @@ static struct PyMethodDef pg_methods[] = { METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, {"cast_hstore", (PyCFunction) pg_cast_hstore, METH_O, pg_cast_hstore__doc__}, -#ifdef DEFAULT_VARS {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, {"get_defbase", pg_get_defbase, METH_NOARGS, pg_get_defbase__doc__}, @@ -1228,11 +1212,8 @@ static struct PyMethodDef pg_methods[] = { {"get_defuser", pg_get_defuser, METH_NOARGS, pg_get_defuser__doc__}, {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, -#endif /* DEFAULT_VARS */ -#ifdef PQLIB_INFO {"get_pqlib_version", (PyCFunction) pg_get_pqlib_version, METH_NOARGS, pg_get_pqlib_version__doc__}, -#endif /* PQLIB_INFO */ {NULL, NULL} /* sentinel */ }; @@ -1260,17 +1241,13 @@ PyMODINIT_FUNC PyInit__pg(void) /* Initialize here because some Windows platforms get confused otherwise */ connType.tp_base = noticeType.tp_base = queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; -#ifdef LARGE_OBJECTS largeType.tp_base = &PyBaseObject_Type; -#endif if (PyType_Ready(&connType) || PyType_Ready(¬iceType) || PyType_Ready(&queryType) || PyType_Ready(&sourceType) -#ifdef LARGE_OBJECTS || PyType_Ready(&largeType) -#endif ) { return NULL; @@ -1354,7 +1331,6 @@ PyMODINIT_FUNC PyInit__pg(void) PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); PyDict_SetItemString(dict, "POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); -#ifdef LARGE_OBJECTS /* Create mode for large objects */ PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); PyDict_SetItemString(dict, "INV_WRITE", PyLong_FromLong(INV_WRITE)); @@ -1363,9 +1339,7 @@ PyMODINIT_FUNC PyInit__pg(void) PyDict_SetItemString(dict, "SEEK_SET", PyLong_FromLong(SEEK_SET)); PyDict_SetItemString(dict, "SEEK_CUR", PyLong_FromLong(SEEK_CUR)); PyDict_SetItemString(dict, "SEEK_END", PyLong_FromLong(SEEK_END)); -#endif /* LARGE_OBJECTS */ -#ifdef DEFAULT_VARS /* Prepare default values */ Py_INCREF(Py_None); pg_default_host = Py_None; @@ -1379,7 +1353,6 @@ PyMODINIT_FUNC PyInit__pg(void) pg_default_user = Py_None; Py_INCREF(Py_None); pg_default_passwd = Py_None; -#endif /* DEFAULT_VARS */ /* Store common pg encoding ids */ diff --git a/pgquery.c b/pgquery.c index 0923eb66..1196889a 100644 --- a/pgquery.c +++ b/pgquery.c @@ -246,18 +246,19 @@ query_next(queryObject *self, PyObject *noargs) return row_tuple; } -#ifdef MEMORY_SIZE - /* Get number of bytes allocated for PGresult object */ static char query_memsize__doc__[] = "memsize() -- return number of bytes allocated by query result"; static PyObject * query_memsize(queryObject *self, PyObject *noargs) { +#ifdef MEMORY_SIZE return PyLong_FromSize_t(PQresultMemorySize(self->result)); -} - +#else + set_error_msg(NotSupportedError, "Memory size functions not supported"); + return NULL; #endif /* MEMORY_SIZE */ +} /* Get number of rows. 
*/ static char query_ntuples__doc__[] = @@ -949,10 +950,8 @@ static struct PyMethodDef query_methods[] = { METH_VARARGS, query_fieldinfo__doc__}, {"ntuples", (PyCFunction) query_ntuples, METH_NOARGS, query_ntuples__doc__}, -#ifdef MEMORY_SIZE {"memsize", (PyCFunction) query_memsize, METH_NOARGS, query_memsize__doc__}, -#endif /* MEMORY_SIZE */ {NULL, NULL} }; diff --git a/setup.py b/setup.py index fb5330e8..456e3b5e 100755 --- a/setup.py +++ b/setup.py @@ -104,31 +104,13 @@ class build_pg_ext(build_ext): user_options = build_ext.user_options + [ ('strict', None, "count all compiler warnings as errors"), - ('direct-access', None, "enable direct access functions"), - ('no-direct-access', None, "disable direct access functions"), - ('direct-access', None, "enable direct access functions"), - ('no-direct-access', None, "disable direct access functions"), - ('large-objects', None, "enable large object support"), - ('no-large-objects', None, "disable large object support"), - ('default-vars', None, "enable default variables use"), - ('no-default-vars', None, "disable default variables use"), - ('escaping-funcs', None, "enable string escaping functions"), - ('no-escaping-funcs', None, "disable string escaping functions"), - ('ssl-info', None, "use new ssl info functions"), - ('no-ssl-info', None, "do not use new ssl info functions"), - ('memory-size', None, "enable new memory size function"), - ('no-memory-size', None, "disable new memory size function")] + ('memory-size', None, "enable memory size function"), + ('no-memory-size', None, "disable memory size function")] boolean_options = build_ext.boolean_options + [ - 'strict', 'direct-access', 'large-objects', 'default-vars', - 'escaping-funcs', 'ssl-info', 'memory-size'] + 'strict', 'memory-size'] negative_opt = { - 'no-direct-access': 'direct-access', - 'no-large-objects': 'large-objects', - 'no-default-vars': 'default-vars', - 'no-escaping-funcs': 'escaping-funcs', - 'no-ssl-info': 'ssl-info', 'no-memory-size': 'memory-size'} def get_compiler(self): @@ -138,12 +120,6 @@ def get_compiler(self): def initialize_options(self): build_ext.initialize_options(self) self.strict = False - self.direct_access = None - self.large_objects = None - self.default_vars = None - self.escaping_funcs = None - self.pqlib_info = None - self.ssl_info = None self.memory_size = None supported = pg_version >= (10, 0) if not supported: @@ -155,18 +131,6 @@ def finalize_options(self): build_ext.finalize_options(self) if self.strict: extra_compile_args.append('-Werror') - if self.direct_access is None or self.direct_access: - define_macros.append(('DIRECT_ACCESS', None)) - if self.large_objects is None or self.large_objects: - define_macros.append(('LARGE_OBJECTS', None)) - if self.default_vars is None or self.default_vars: - define_macros.append(('DEFAULT_VARS', None)) - if self.escaping_funcs is None or self.escaping_funcs: - define_macros.append(('ESCAPING_FUNCS', None)) - if self.pqlib_info is None or self.pqlib_info: - define_macros.append(('PQLIB_INFO', None)) - if self.ssl_info is None or self.ssl_info: - define_macros.append(('SSL_INFO', None)) wanted = self.memory_size supported = pg_version >= (12, 0) if (wanted is None and supported) or wanted: diff --git a/tox.ini b/tox.ini index d48b44c7..23fb9379 100644 --- a/tox.ini +++ b/tox.ini @@ -21,5 +21,5 @@ passenv = PG* PYGRESQL_* commands = - python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size + python setup.py clean --all build_ext --force --inplace --strict --memory-size 
python -m unittest {posargs:discover} From 99dd2e07af24a299e39dbea25eee2ee3a09b3342 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 23:52:50 +0200 Subject: [PATCH 131/194] Remove support for very old PostgreSQL versions --- docs/contents/pg/connection.rst | 4 +- docs/contents/pg/module.rst | 4 +- pg.py | 80 +++++++++------------------ pgdb.py | 17 ++---- tests/test_classic_connection.py | 8 +-- tests/test_classic_dbwrapper.py | 93 +++++--------------------------- 6 files changed, 50 insertions(+), 156 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 7fd44cca..237e25a8 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -749,10 +749,10 @@ the connection and its status. These attributes are: this is True if the connection uses SSL, False if not -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 .. attribute:: Connection.ssl_attributes SSL-related information about the connection (dict) -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 9faa3754..203ada03 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -81,9 +81,9 @@ get_pqlib_version -- get the version of libpq The number is formed by converting the major, minor, and revision numbers of the libpq version into two-decimal-digit numbers and appending them together. -For example, version 9.1.2 will be returned as 90102. +For example, version 15.4 will be returned as 150400. -.. versionadded:: 5.2 (needs PostgreSQL >= 9.1) +.. versionadded:: 5.2 get/set_defhost -- default server host -------------------------------------- diff --git a/pg.py b/pg.py index 99e4aa62..896af911 100644 --- a/pg.py +++ b/pg.py @@ -1126,19 +1126,11 @@ def __init__(self, db): self._typecasts = Typecasts() self._typecasts.get_attnames = self.get_attnames self._typecasts.connection = self._db - if db.server_version < 80400: - # very old remote databases (not officially supported) - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") - else: - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") + self._query_pg_type = ( + "SELECT oid, typname, oid::pg_catalog.regtype," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type" + " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") def add(self, oid, pgtype, regtype, typlen, typtype, category, delim, relid): @@ -1162,7 +1154,7 @@ def add(self, oid, pgtype, regtype, def __missing__(self, key): """Get the type info from the database if it is not cached.""" try: - q = self._query_pg_type % (_quote_if_unqualified('$1', key),) + q = self._query_pg_type.format(_quote_if_unqualified('$1', key)) res = self._db.query(q, (key,)).getresult() except ProgrammingError: res = None @@ -1493,33 +1485,17 @@ def __init__(self, *args, **kw): self._privileges = {} self.adapter = Adapter(self) self.dbtypes = DbTypes(self) - if db.server_version < 80400: - # very old remote databases (not officially supported) - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, 
t.typtype, null as typcategory," - " t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND %s AND NOT a.attisdropped ORDER BY a.attnum") - else: - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND %s AND NOT a.attisdropped ORDER BY a.attnum") - if db.server_version < 100000: - self._query_generated = None - elif db.server_version < 120000: + self._query_attnames = ( + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," + " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" + " FROM pg_catalog.pg_attribute a" + " JOIN pg_catalog.pg_type t" + " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND {} AND NOT a.attisdropped ORDER BY a.attnum") + if db.server_version < 120000: self._query_generated = ( "a.attidentity OPERATOR(pg_catalog.=) 'a'" ) @@ -2052,8 +2028,9 @@ def get_attnames(self, table, with_oid=True, flush=False): except KeyError: # cache miss, check the database q = "a.attnum OPERATOR(pg_catalog.>) 0" if with_oid: - q = "(%s OR a.attname OPERATOR(pg_catalog.=) 'oid')" % q - q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + q = f"({q} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + q = self._query_attnames.format( + _quote_if_unqualified('$1', table), q) names = self.db.query(q, (table,)).getresult() types = self.dbtypes names = ((name[0], types.add(*name[1:])) for name in names) @@ -2070,9 +2047,6 @@ def get_generated(self, table, flush=False): be flushed. This may be necessary after the database schema or the search path has been changed. 
""" - query_generated = self._query_generated - if not query_generated: - return frozenset() generated = self._generated if flush: generated.clear() @@ -2080,8 +2054,10 @@ def get_generated(self, table, flush=False): try: # cache lookup names = generated[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0 AND " + query_generated - q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + q = "a.attnum OPERATOR(pg_catalog.>) 0" + q = f"{q} AND {self._query_generated}" + q = self._query_attnames.format( + _quote_if_unqualified('$1', table), q) names = self.db.query(q, (table,)).getresult() names = frozenset(name[0] for name in names) generated[table] = names # cache it @@ -2394,13 +2370,7 @@ def upsert(self, table, row=None, **kw): ' ON CONFLICT (%s) DO %s RETURNING %s') % ( self._escape_qualified_name(table), names, values, target, do, ret) self._do_debug(q, params) - try: - q = self.db.query(q, params) - except ProgrammingError: - if self.server_version < 90500: - raise _prg_error( - 'Upsert operation is not supported by PostgreSQL version') - raise # re-raise original error + q = self.db.query(q, params) res = q.dictresult() if res: # may be empty with "do nothing" for n, value in res[0].items(): diff --git a/pgdb.py b/pgdb.py index 4de78e15..f986242f 100644 --- a/pgdb.py +++ b/pgdb.py @@ -629,17 +629,10 @@ def __init__(self, cnx): self._typecasts = LocalTypecasts() self._typecasts.get_fields = self.get_fields self._typecasts.connection = cnx - if cnx.server_version < 80400: - # older remote databases (not officially supported) - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") - else: - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") + self._query_pg_type = ( + "SELECT oid, typname," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") def __missing__(self, key): """Get the type info from the database if it is not cached.""" @@ -650,7 +643,7 @@ def __missing__(self, key): key = '"%s"' % (key,) oid = "'%s'::pg_catalog.regtype" % (self._escape_string(key),) try: - self._src.execute(self._query_pg_type % (oid,)) + self._src.execute(self._query_pg_type.format(oid)) except ProgrammingError: res = None else: diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 068dd792..c456b4ec 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -176,7 +176,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 160000) + self.assertTrue(100000 <= server_version < 160000) def testAttributeSocket(self): socket = self.connection.socket @@ -2704,11 +2704,7 @@ def setUpClass(cls): query = db.query query('set client_encoding=sql_ascii') query('set standard_conforming_strings=off') - try: - query('set bytea_output=escape') - except pg.ProgrammingError: - if db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=escape') db.close() cls.cls_set_up = True diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 25c3c11d..3d372ad3 100755 --- 
a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -254,7 +254,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 160000) + self.assertTrue(100000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): @@ -456,11 +456,7 @@ def setUp(self): query("set lc_monetary='C'") query("set datestyle='ISO,YMD'") query('set standard_conforming_strings=on') - try: - query('set bytea_output=hex') - except pg.ProgrammingError: - if self.db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=hex') def tearDown(self): self.doCleanups() @@ -1951,13 +1947,7 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd')]) r = dict(i4=5678, v4='efgh') - try: - insert('test_view', r) - except (pg.OperationalError, pg.NotSupportedError) as error: - if self.db.server_version < 90300: - # must setup rules in older PostgreSQL versions - self.skipTest('database cannot insert into view') - self.fail(str(error)) + insert('test_view', r) self.assertNotIn('i2', r) self.assertEqual(r['i4'], 5678) self.assertNotIn('i8', r) @@ -2203,12 +2193,7 @@ def testUpsert(self): table = 'upsert_test_table' self.createTable(table, 'n integer primary key, t text') s = dict(n=1, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['t'], 'x') @@ -2296,12 +2281,7 @@ def testUpsertWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) s = dict(n=2) - try: - r = upsert('test_table', s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert('test_table', s) self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertIsNone(r['m']) @@ -2366,12 +2346,7 @@ def testUpsertWithCompositeKey(self): self.createTable( table, 'n integer, m integer, t text, primary key (n, m)') s = dict(n=1, m=2, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 2) @@ -2433,12 +2408,7 @@ def testUpsertWithQuotedNames(self): self.createTable(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text') s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) @@ -2456,8 +2426,6 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r, [(31, 9009, 'No.')]) def testUpsertWithGeneratedColumns(self): - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') upsert = self.db.upsert get = self.db.get server_version = self.db.server_version @@ -3378,12 +3346,7 @@ def testUpsertBytea(self): self.createTable('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = dict(n=7, data=s) - try: - r = self.db.upsert('bytea_test', r) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = self.db.upsert('bytea_test', r) self.assertIsInstance(r, dict) self.assertIn('n', r) self.assertEqual(r['n'], 7) @@ -3402,12 +3365,7 @@ def testUpsertBytea(self): self.assertIsNone(r['data']) def testInsertGetJson(self): - try: - self.createTable('json_test', 'n smallint primary key, data json') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + self.createTable('json_test', 'n smallint primary key, data json') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('json_test', n=0, data=None) @@ -3471,13 +3429,8 @@ def testInsertGetJson(self): self.assertEqual(r[0][0], r[1][0]) def testInsertGetJsonb(self): - try: - self.createTable('jsonb_test', - 'n smallint primary key, data jsonb') - except pg.ProgrammingError as error: - if self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + self.createTable('jsonb_test', + 'n smallint primary key, data jsonb') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('jsonb_test', n=0, data=None) @@ -3703,13 +3656,7 @@ def testArrayOfBytea(self): self.assertNotEqual(r['data'], data) def testArrayOfJson(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data json[]') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + self.createTable('arraytest', 'id serial primary key, data json[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3751,13 +3698,7 @@ def testArrayOfJson(self): self.assertEqual(r, '{NULL,NULL}') def testArrayOfJsonb(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data jsonb[]') - except pg.ProgrammingError as error: - if self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + self.createTable('arraytest', 'id serial primary key, data jsonb[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3941,13 +3882,7 @@ def testRecordInsertBytea(self): def testRecordInsertJson(self): query = self.db.query - try: - query('create type test_person_type as' - ' (name text, data 
json)') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + query('create type test_person_type as (name text, data json)') self.addCleanup(query, 'drop type test_person_type') self.createTable('test_person', 'person test_person_type', temporary=False) From 7e53673a2cfe541983740b01ba99bec210830acc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 30 Aug 2023 01:07:00 +0200 Subject: [PATCH 132/194] Modernize string formatting in pg and pgdb --- pg.py | 192 +++++++++++++++++++++++++++++--------------------------- pgdb.py | 95 ++++++++++++++-------------- 2 files changed, 150 insertions(+), 137 deletions(-) diff --git a/pg.py b/pg.py index 896af911..50f22425 100644 --- a/pg.py +++ b/pg.py @@ -49,7 +49,7 @@ if e: raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) from e + f"probably because no {libpq} is installed.\n{e}") from e else: del version @@ -148,7 +148,7 @@ def _timezone_as_offset(tz): def _oid_key(table): """Build oid key from a table name.""" - return 'oid(%s)' % table + return f'oid({table})' class Bytea(bytes): @@ -170,12 +170,12 @@ def _quote(cls, s): return '""' s = s.replace('"', '\\"') if cls._re_quote.search(s): - s = '"%s"' % s + s = f'"{s}"' return s def __str__(self): q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) class Json: @@ -220,9 +220,9 @@ def __init__(self): for key in keys: self[key] = typ if isinstance(key, str): - self['_%s' % key] = '%s[]' % typ + self[f'_{key}'] = f'{typ}[]' elif not isinstance(key, tuple): - self[List[key]] = '%s[]' % typ + self[List[key]] = f'{typ}[]' @staticmethod def __missing__(key): @@ -248,7 +248,7 @@ def _quote_if_unqualified(param, name): and must be quoted manually by the caller. """ if isinstance(name, str) and '.' 
not in name: - return 'quote_ident(%s)' % (param,) + return f'quote_ident({param})' return param @@ -266,7 +266,7 @@ def add(self, value, typ=None): if isinstance(value, Literal): return value self.append(value) - return '$%d' % len(self) + return f'${len(self)}' class Literal(str): @@ -366,7 +366,7 @@ def _adapt_hstore(self, v): return str(v) if isinstance(v, dict): return str(Hstore(v)) - raise TypeError('Hstore parameter %s has wrong type' % v) + raise TypeError(f'Hstore parameter {v} has wrong type') def _adapt_uuid(self, v): """Adapt a UUID parameter.""" @@ -381,14 +381,15 @@ def _adapt_text_array(cls, v): """Adapt a text type array parameter.""" if isinstance(v, list): adapt = cls._adapt_text_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if v is None: return 'null' if not v: return '""' v = str(v) if cls._re_array_quote.search(v): - v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v) + v = cls._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' return v _adapt_date_array = _adapt_text_array @@ -398,7 +399,7 @@ def _adapt_bool_array(cls, v): """Adapt a boolean array parameter.""" if isinstance(v, list): adapt = cls._adapt_bool_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if v is None: return 'null' if isinstance(v, str): @@ -412,7 +413,7 @@ def _adapt_num_array(cls, v): """Adapt a numeric array parameter.""" if isinstance(v, list): adapt = cls._adapt_num_array - return '{%s}' % ','.join(adapt(v) for v in v) + v = '{' + ','.join(adapt(v) for v in v) + '}' if not v and v != 0: return 'null' return str(v) @@ -433,20 +434,21 @@ def _adapt_json_array(self, v): """Adapt a json array parameter.""" if isinstance(v, list): adapt = self._adapt_json_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if not v: return 'null' if not isinstance(v, str): v = self.db.encode_json(v) if self._re_array_quote.search(v): - v = '"%s"' % self._re_array_escape.sub(r'\\\1', v) + v = self._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' return v def _adapt_record(self, v, typ): """Adapt a record parameter with given type.""" typ = self.get_attnames(typ).values() if len(typ) != len(v): - raise TypeError('Record parameter %s has wrong size' % v) + raise TypeError(f'Record parameter {v} has wrong size') adapt = self.adapt value = [] for v, t in zip(v, typ): @@ -462,9 +464,11 @@ def _adapt_record(self, v, typ): else: v = str(v) if self._re_record_quote.search(v): - v = '"%s"' % self._re_record_escape.sub(r'\\\1', v) + v = self._re_record_escape.sub(r'\\\1', v) + v = f'"{v}"' value.append(v) - return '(%s)' % ','.join(value) + v = ','.join(value) + return f'({v})' def adapt(self, value, typ=None): """Adapt a value with known database type.""" @@ -483,10 +487,10 @@ def adapt(self, value, typ=None): value = self._adapt_record(value, typ) elif simple.endswith('[]'): if isinstance(value, list): - adapt = getattr(self, '_adapt_%s_array' % simple[:-2]) + adapt = getattr(self, f'_adapt_{simple[:-2]}_array') value = adapt(value) else: - adapt = getattr(self, '_adapt_%s' % simple) + adapt = getattr(self, f'_adapt_{simple}') value = adapt(value) return value @@ -541,7 +545,7 @@ def guess_simple_type(cls, value): if isinstance(value, UUID): return 'uuid' if isinstance(value, list): - return '%s[]' % (cls.guess_simple_base_type(value) or 'text',) + return (cls.guess_simple_base_type(value) or 'text') + '[]' if isinstance(value, tuple): simple_type = cls.simple_type 
guess = cls.guess_simple_type @@ -578,7 +582,7 @@ def adapt_inline(self, value, nested=False): value = str(value) if isinstance(value, (bytes, str)): value = self.db.escape_string(value) - return "'%s'" % value + return f"'{value}'" if isinstance(value, bool): return 'true' if value else 'false' if isinstance(value, float): @@ -591,21 +595,21 @@ def adapt_inline(self, value, nested=False): return value if isinstance(value, list): q = self.adapt_inline - s = '[%s]' if nested else 'ARRAY[%s]' - return s % ','.join(str(q(v, nested=True)) for v in value) + s = '[{}]' if nested else 'ARRAY[{}]' + return s.format(','.join(str(q(v, nested=True)) for v in value)) if isinstance(value, tuple): q = self.adapt_inline - return '(%s)' % ','.join(str(q(v)) for v in value) + return '({})'.format(','.join(str(q(v)) for v in value)) if isinstance(value, Json): value = self.db.escape_string(str(value)) - return "'%s'::json" % value + return f"'{value}'::json" if isinstance(value, Hstore): value = self.db.escape_string(str(value)) - return "'%s'::hstore" % value + return f"'{value}'::hstore" pg_repr = getattr(value, '__pg_repr__', None) if not pg_repr: raise InterfaceError( - 'Do not know how to adapt type %s' % type(value)) + f'Do not know how to adapt type {type(value)}') value = pg_repr() if isinstance(value, (tuple, list)): value = self.adapt_inline(value) @@ -903,7 +907,7 @@ def cast_interval(value): secs = -secs usecs = -usecs else: - raise ValueError('Cannot parse interval: %s' % value) + raise ValueError(f'Cannot parse interval: {value}') days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) @@ -946,7 +950,7 @@ def __missing__(self, typ): but returns None when no special cast function exists. """ if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) + raise TypeError('Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -992,13 +996,13 @@ def set(self, typ, cast): if cast is None: for t in typ: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def reset(self, typ=None): """Reset the typecasts for the specified type(s) to their defaults. 
@@ -1027,13 +1031,13 @@ def set_default(cls, typ, cast): if cast is None: for t in typ: defaults.pop(t, None) - defaults.pop('_%s' % t, None) + defaults.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: defaults[t] = cast - defaults.pop('_%s' % t, None) + defaults.pop(f'_{t}', None) # noinspection PyMethodMayBeStatic,PyUnusedLocal def get_attnames(self, typ): @@ -1159,7 +1163,7 @@ def __missing__(self, key): except ProgrammingError: res = None if not res: - raise KeyError('Type %s could not be found' % (key,)) + raise KeyError(f'Type {key} could not be found') res = res[0] typ = self.add(*res) self[typ.oid] = self[typ.pgtype] = typ @@ -1224,7 +1228,7 @@ def _row_factory(names): try: return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] + names = [f'column_{n}' for n in range(len(names))] return namedtuple('Row', names)._make @@ -1335,7 +1339,7 @@ def __init__(self, db, event, callback=None, """ self.db = db self.event = event - self.stop_event = stop_event or 'stop_%s' % event + self.stop_event = stop_event or f'stop_{event}' self.listening = False self.callback = callback if arg_dict is None: @@ -1356,15 +1360,15 @@ def close(self): def listen(self): """Start listening for the event and the stop event.""" if not self.listening: - self.db.query('listen "%s"' % self.event) - self.db.query('listen "%s"' % self.stop_event) + self.db.query(f'listen "{self.event}"') + self.db.query(f'listen "{self.stop_event}"') self.listening = True def unlisten(self): """Stop listening for the event and the stop event.""" if self.listening: - self.db.query('unlisten "%s"' % self.event) - self.db.query('unlisten "%s"' % self.stop_event) + self.db.query(f'unlisten "{self.event}"') + self.db.query(f'unlisten "{self.stop_event}"') self.listening = False def notify(self, db=None, stop=False, payload=None): @@ -1382,9 +1386,10 @@ def notify(self, db=None, stop=False, payload=None): if self.listening: if not db: db = self.db - q = 'notify "%s"' % (self.stop_event if stop else self.event) + event = self.stop_event if stop else self.event + q = f'notify "{event}"' if payload: - q += ", '%s'" % payload + q += f", '{payload}'" return db.query(q) def __call__(self): @@ -1420,8 +1425,9 @@ def __call__(self): if event not in (self.event, self.stop_event): self.unlisten() raise _db_error( - 'Listening for "%s" and "%s", but notified of "%s"' - % (self.event, self.stop_event, event)) + f'Listening for "{self.event}"' + f' and "{self.stop_event}",' + f' but notified of "{event}"') if event == self.stop_event: self.unlisten() self.arg_dict.update(pid=pid, event=event, extra=extra) @@ -1592,7 +1598,7 @@ def _make_bool(d): @staticmethod def _list_params(params): """Create a human readable parameter list.""" - return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1)) + return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) # Public methods @@ -1735,7 +1741,7 @@ def get_parameter(self, parameter): params.append(param) else: for param in params: - q = 'SHOW %s' % (param,) + q = f'SHOW {param}' value = self.db.query(q).singlescalar() if values is None: values = value @@ -1813,9 +1819,9 @@ def set_parameter(self, parameter, value=None, local=False): local = ' LOCAL' if local else '' for param, value in params.items(): if value is None: - q = 'RESET%s %s' % (local, param) + q = f'RESET{local} {param}' else: - q = 'SET%s %s TO %s' % 
(local, param, value) + q = f'SET{local} {param} TO {value}' self._do_debug(q) self.db.query(q) @@ -1919,7 +1925,9 @@ def delete_prepared(self, name=None): name. Note that prepared statements are also deallocated automatically when the current session ends. """ - q = "DEALLOCATE %s" % (name or 'ALL',) + if not name: + name = 'ALL' + q = f"DEALLOCATE {name}" self._do_debug(q) return self.db.query(q) @@ -1949,12 +1957,12 @@ def pkey(self, table, composite=False, flush=False): " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" " AND NOT a.attisdropped" " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum") % ( - _quote_if_unqualified('$1', table),) + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + _quote_if_unqualified('$1', table)) pkey = self.db.query(q, (table,)).getresult() if not pkey: - raise KeyError('Table %s has no primary key' % table) + raise KeyError(f'Table {table} has no primary key') # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: @@ -1984,18 +1992,18 @@ def get_relations(self, kinds=None, system=False): """ where = [] if kinds: - where.append("r.relkind IN (%s)" % - ','.join("'%s'" % k for k in kinds)) + where.append( + "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) if not system: where.append("s.nspname NOT SIMILAR" " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE %s" % ' AND '.join(where) if where else '' + where = " WHERE " + ' AND '.join(where) if where else '' q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" " FROM pg_catalog.pg_class r" " JOIN pg_catalog.pg_namespace s" - " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s" - " ORDER BY s.nspname, r.relname") % where + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") return [r[0] for r in self.db.query(q).getresult()] def get_tables(self, system=False): @@ -2089,8 +2097,8 @@ def has_table_privilege(self, table, privilege='select', flush=False): try: # ask cache ret = privileges[table, privilege] except KeyError: # cache miss, ask the database - q = "SELECT pg_catalog.has_table_privilege(%s, $2)" % ( - _quote_if_unqualified('$1', table),) + q = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + _quote_if_unqualified('$1', table)) q = self.db.query(q, (table, privilege)) ret = q.singlescalar() == self._make_bool(True) privileges[table, privilege] = ret # cache it @@ -2130,7 +2138,7 @@ def get(self, table, row, keyname=None): if qoid and isinstance(row, dict) and 'oid' in row: keyname = ('oid',) else: - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') else: # the table has a primary key # check whether all key columns have values if isinstance(row, dict) and not set(keyname).issubset(row): @@ -2151,22 +2159,23 @@ def get(self, table, row, keyname=None): adapt = params.add col = self.escape_identifier what = 'oid, *' if qoid else '*' - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % ( - what, self._escape_qualified_name(table), where) + t = self._escape_qualified_name(table) + q = 
f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() if not res: # make where clause in error message better readable where = where.replace('OPERATOR(pg_catalog.=)', '=') - raise _db_error('No such record in %s\nwhere %s\nwith %s' % ( - table, where, self._list_params(params))) + raise _db_error( + f'No such record in {table}\nwhere {where}\nwith ' + + self._list_params(params)) for n, value in res[0].items(): if qoid and n == 'oid': n = qoid @@ -2208,8 +2217,8 @@ def insert(self, table, row=None, **kw): raise _prg_error('No column found that can be inserted') names, values = ', '.join(names), ', '.join(values) ret = 'oid, *' if qoid else '*' - q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % ( - self._escape_qualified_name(table), names, values, ret) + t = self._escape_qualified_name(table) + q = f'INSERT INTO {t} ({names}) VALUES ({values}) RETURNING {ret}' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2249,14 +2258,14 @@ def update(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: @@ -2266,13 +2275,14 @@ def update(self, table, row=None, **kw): keyname = set(keyname) for n in attnames: if n in row and n not in keyname and n not in generated: - values.append('%s = %s' % (col(n), adapt(row[n], attnames[n]))) + values.append('{} = {}'.format( + col(n), adapt(row[n], attnames[n]))) if not values: return row values = ', '.join(values) ret = 'oid, *' if qoid else '*' - q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % ( - self._escape_qualified_name(table), values, where, ret) + t = self._escape_qualified_name(table) + q = f'UPDATE {t} SET {values} WHERE {where} RETURNING {ret}' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2350,7 +2360,7 @@ def upsert(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') target = ', '.join(col(k) for k in keyname) update = [] keyname = set(keyname) @@ -2360,15 +2370,15 @@ def upsert(self, table, row=None, **kw): value = kw.get(n, n in row) if value: if not isinstance(value, str): - value = 'excluded.%s' % col(n) - update.append('%s = %s' % (col(n), value)) + value = f'excluded.{col(n)}' + update.append(f'{col(n)} = {value}') if not values: return row - do = 'update set %s' % ', '.join(update) if update else 'nothing' + do = 'update set ' + ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' - q = ('INSERT INTO %s AS included (%s) VALUES (%s)' - ' ON CONFLICT (%s) DO %s RETURNING %s') % ( - self._escape_qualified_name(table), names, values, target, do, ret) + t = self._escape_qualified_name(table) + q = (f'INSERT INTO {t} AS included ({names}) VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') self._do_debug(q, params) q = self.db.query(q, 
params) res = q.dictresult() @@ -2435,21 +2445,21 @@ def delete(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - q = 'DELETE FROM %s WHERE %s' % ( - self._escape_qualified_name(table), where) + t = self._escape_qualified_name(table) + q = f'DELETE FROM {t} WHERE {where}' self._do_debug(q, params) res = self.db.query(q, params) return int(res) @@ -2499,7 +2509,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): t = t[:-1].rstrip() t = self._escape_qualified_name(t) if u: - t = 'ONLY %s' % t + t = f'ONLY {t}' tables.append(t) q = ['TRUNCATE', ', '.join(tables)] if restart: @@ -2565,9 +2575,9 @@ def get_as_list(self, table, what=None, where=None, order = ', '.join(map(str, order)) q.extend(['ORDER BY', order]) if limit: - q.append('LIMIT %d' % limit) + q.append(f'LIMIT {limit}') if offset: - q.append('OFFSET %d' % offset) + q.append(f'OFFSET {offset}') q = ' '.join(q) self._do_debug(q) q = self.db.query(q) @@ -2603,7 +2613,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, try: keyname = self.pkey(table, True) except (KeyError, ProgrammingError): - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): @@ -2627,9 +2637,9 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, order = ', '.join(map(str, order)) q.extend(['ORDER BY', order]) if limit: - q.append('LIMIT %d' % limit) + q.append(f'LIMIT {limit}') if offset: - q.append('OFFSET %d' % offset) + q.append(f'OFFSET {offset}') q = ' '.join(q) self._do_debug(q) q = self.db.query(q) diff --git a/pgdb.py b/pgdb.py index f986242f..5e218b42 100644 --- a/pgdb.py +++ b/pgdb.py @@ -93,7 +93,7 @@ if e: raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) from e + f"probably because no {libpq} is installed.\n{e}") from e else: del version @@ -389,7 +389,7 @@ def cast_interval(value): secs = -secs usecs = -usecs else: - raise ValueError('Cannot parse interval: %s' % value) + raise ValueError(f'Cannot parse interval: {value}') days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) @@ -429,7 +429,7 @@ def __missing__(self, typ): but returns None when no special cast function exists. 
""" if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) + raise TypeError(f'Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -471,13 +471,13 @@ def set(self, typ, cast): if cast is None: for t in typ: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def reset(self, typ=None): """Reset the typecasts for the specified type(s) to their defaults. @@ -495,7 +495,7 @@ def reset(self, typ=None): cast = defaults.get(t) if cast: self[t] = self._add_connection(cast) - t = '_%s' % t + t = f'_{t}' cast = defaults.get(t) if cast: self[t] = self._add_connection(cast) @@ -503,7 +503,7 @@ def reset(self, typ=None): self.pop(t, None) else: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def create_array_cast(self, basecast): """Create an array typecast for the given base cast.""" @@ -640,8 +640,8 @@ def __missing__(self, key): oid = key else: if '.' not in key and '"' not in key: - key = '"%s"' % (key,) - oid = "'%s'::pg_catalog.regtype" % (self._escape_string(key),) + key = f'"{key}"' + oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" try: self._src.execute(self._query_pg_type.format(oid)) except ProgrammingError: @@ -649,7 +649,7 @@ def __missing__(self, key): else: res = self._src.fetch(1) if not res: - raise KeyError('Type %s could not be found' % (key,)) + raise KeyError(f'Type {key} could not be found') res = res[0] type_code = TypeCode.create( int(res[0]), res[1], int(res[2]), @@ -676,9 +676,9 @@ def get_fields(self, typ): self._src.execute( "SELECT attname, atttypid" " FROM pg_catalog.pg_attribute" - " WHERE attrelid OPERATOR(pg_catalog.=) %s" + f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}" " AND attnum OPERATOR(pg_catalog.>) 0" - " AND NOT attisdropped ORDER BY attnum" % (typ.relid,)) + " AND NOT attisdropped ORDER BY attnum") return [FieldInfo(name, self.get(int(oid))) for name, oid in self._src.fetch(-1)] @@ -761,7 +761,7 @@ def _row_factory(names): try: return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] + names = [f'column_{n}' for n in range(len(names))] return namedtuple('Row', names)._make @@ -820,7 +820,7 @@ def _quote(self, value): value = self._cnx.escape_bytea(value).decode('ascii') else: value = self._cnx.escape_string(value) - return "'%s'" % (value,) + return f"'{value}'" if isinstance(value, float): if isinf(value): return "'-Infinity'" if value < 0 else "'Infinity'" @@ -831,18 +831,18 @@ def _quote(self, value): return value if isinstance(value, datetime): if value.tzinfo: - return "'%s'::timestamptz" % (value,) - return "'%s'::timestamp" % (value,) + return f"'{value}'::timestamptz" + return f"'{value}'::timestamp" if isinstance(value, date): - return "'%s'::date" % (value,) + return f"'{value}'::date" if isinstance(value, time): if value.tzinfo: - return "'%s'::timetz" % (value,) - return "'%s'::time" % value + return f"'{value}'::timetz" + return f"'{value}'::time" if isinstance(value, timedelta): - return "'%s'::interval" % (value,) + return f"'{value}'::interval" if isinstance(value, Uuid): - return "'%s'::uuid" % (value,) + return f"'{value}'::uuid" if isinstance(value, list): # Quote value as an ARRAY constructor. 
This is better than using # an array literal because it carries the information that this is @@ -852,7 +852,8 @@ def _quote(self, value): if not value: # exception for empty array return "'{}'" q = self._quote - return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),) + v = ','.join(str(q(v)) for v in value) + return f'ARRAY[{v}]' if isinstance(value, tuple): # Quote as a ROW constructor. This is better than using a record # literal because it carries the information that this is a record @@ -860,12 +861,13 @@ def _quote(self, value): # this usable with the IN syntax as well. It is only necessary # when the records has a single column which is not really useful. q = self._quote - return '(%s)' % (','.join(str(q(v)) for v in value),) + v = ','.join(str(q(v)) for v in value) + return f'({v})' try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() except AttributeError: raise InterfaceError( - 'Do not know how to adapt type %s' % (type(value),)) + f'Do not know how to adapt type {type(value)}') if isinstance(value, (tuple, list)): value = self._quote(value) return value @@ -979,10 +981,9 @@ def executemany(self, operation, seq_of_parameters): raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error( - "Error in '%s': '%s' " % (sql, err), InterfaceError) + raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) except Exception as err: - raise _op_error("Internal error in '%s': %s" % (sql, err)) + raise _op_error(f"Internal error in '{sql}': {err}") # then initialize result raw count and description if self._src.resulttype == RESULT_DQL: self._description = True # fetch on demand @@ -1049,8 +1050,9 @@ def callproc(self, procname, parameters=None): The procedure may also provide a result set as output. These can be requested through the standard fetch methods of the cursor. 
""" - n = parameters and len(parameters) or 0 - query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s'])) + n = len(parameters) if parameters else 0 + s = ','.join(n * ['%s']) + query = f'select * from "{procname}"({s})' self.execute(query, parameters) return parameters @@ -1088,7 +1090,7 @@ def copy_from(self, stream, table, if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): - raise ValueError("The input must be %s" % (type_name,)) + raise ValueError(f"The input must be {type_name}") if not binary_format: if isinstance(stream, str): if not stream.endswith('\n'): @@ -1106,8 +1108,7 @@ def chunks(): for chunk in stream: if not isinstance(chunk, input_type): raise ValueError( - "Input stream must consist of %s" - % (type_name,)) + f"Input stream must consist of {type_name}") if isinstance(chunk, str): if not chunk.endswith('\n'): chunk += '\n' @@ -1144,7 +1145,7 @@ def chunks(): else: table = '.'.join(map( self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = ['copy %s' % (table,)] + operation = [f'copy {table}'] options = [] params = [] if format is not None: @@ -1152,7 +1153,7 @@ def chunks(): raise TypeError("The format option must be be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") - options.append('format %s' % (format,)) + options.append(f'format {format}') if sep is not None: if not isinstance(sep, str): raise TypeError("The sep option must be a string") @@ -1173,10 +1174,11 @@ def chunks(): if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) - operation.append('(%s)' % (columns,)) + operation.append(f'({columns})') operation.append("from stdin") if options: - operation.append('(%s)' % (','.join(options),)) + options = ','.join(options) + operation.append(f'({options})') operation = ' '.join(operation) putdata = self._src.putdata @@ -1226,11 +1228,11 @@ def copy_to(self, stream, table, if table.lower().startswith('select '): if columns: raise ValueError("Columns must be specified in the query") - table = '(%s)' % (table,) + table = f'({table})' else: table = '.'.join(map( self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = ['copy %s' % (table,)] + operation = [f'copy {table}'] options = [] params = [] if format is not None: @@ -1238,7 +1240,7 @@ def copy_to(self, stream, table, raise TypeError("The format option must be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") - options.append('format %s' % (format,)) + options.append(f'format {format}') if sep is not None: if not isinstance(sep, str): raise TypeError("The sep option must be a string") @@ -1267,11 +1269,12 @@ def copy_to(self, stream, table, if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) - operation.append('(%s)' % (columns,)) + operation.append(f'({columns})') operation.append("to stdout") if options: - operation.append('(%s)' % (','.join(options),)) + options = ','.join(options) + operation.append(f'({options})') operation = ' '.join(operation) getdata = self._src.getdata @@ -1553,9 +1556,9 @@ def connect(dsn=None, for kw, value in kwargs: value = str(value) if not value or ' ' in value: - value = "'%s'" % (value.replace( - '\\', '\\\\').replace("'", "\\'")) - dbname.append('%s=%s' % (kw, value)) + value = value.replace('\\', '\\\\').replace("'", "\\'") + value = f"'{value}'" + dbname.append(f'{kw}={value}') dbname = ' '.join(dbname) # open the 
connection # noinspection PyArgumentList @@ -1734,12 +1737,12 @@ def _quote(cls, s): quote = cls._re_quote.search(s) s = cls._re_escape.sub(r'\\\1', s) if quote: - s = '"%s"' % (s,) + s = f'"{s}"' return s def __str__(self): q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) class Json: From a7fe116b61fe8d30c7e9595d7ca8820987bbb9ec Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 13:23:17 +0200 Subject: [PATCH 133/194] Modernize string formatting in tests --- tests/dbapi20.py | 212 ++++++++++------------ tests/test_classic.py | 8 +- tests/test_classic_connection.py | 154 ++++++++-------- tests/test_classic_dbwrapper.py | 270 ++++++++++++++--------------- tests/test_classic_functions.py | 31 ++-- tests/test_classic_largeobj.py | 20 +-- tests/test_classic_notification.py | 14 +- tests/test_dbapi20.py | 196 +++++++++++---------- tests/test_dbapi20_copy.py | 24 +-- tests/test_tutorial.py | 2 +- 10 files changed, 452 insertions(+), 479 deletions(-) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index b793fbf2..798bbc49 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -3,6 +3,8 @@ """Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. + +Some modernization of the code has been done by the PyGreSQL team. """ __version__ = '1.15.0' @@ -10,6 +12,8 @@ import unittest import time +from typing import Any, Dict, Tuple + class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -36,16 +40,16 @@ class mytest(dbapi20.DatabaseAPI20Test): # The self.driver module. This should be the module where the 'connect' # method is to be found - driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect + driver: Any = None + connect_args: Tuple = () # List of arguments to pass to connect + connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % (table_prefix,) - ddl2 = 'create table %sbarflys (name varchar(20), drink varchar(30))' % ( - table_prefix,) - xddl1 = 'drop table %sbooze' % (table_prefix,) - xddl2 = 'drop table %sbarflys' % (table_prefix,) + ddl1 = f'create table {table_prefix}booze (name varchar(20))' + ddl2 = (f'create table {table_prefix}barflys (name varchar(20),' + ' drink varchar(30))') + xddl1 = f'drop table {table_prefix}booze' + xddl2 = f'drop table {table_prefix}barflys' insert = 'insert' lowerfunc = 'lower' # Name of stored procedure to convert str to lowercase @@ -155,15 +159,15 @@ def test_ExceptionsAsConnectionAttributes(self): # by default. 
con = self._connect() drv = self.driver - self.assertTrue(con.Warning is drv.Warning) - self.assertTrue(con.Error is drv.Error) - self.assertTrue(con.InterfaceError is drv.InterfaceError) - self.assertTrue(con.DatabaseError is drv.DatabaseError) - self.assertTrue(con.OperationalError is drv.OperationalError) - self.assertTrue(con.IntegrityError is drv.IntegrityError) - self.assertTrue(con.InternalError is drv.InternalError) - self.assertTrue(con.ProgrammingError is drv.ProgrammingError) - self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + self.assertIs(con.Warning, drv.Warning) + self.assertIs(con.Error, drv.Error) + self.assertIs(con.InterfaceError, drv.InterfaceError) + self.assertIs(con.DatabaseError, drv.DatabaseError) + self.assertIs(con.OperationalError, drv.OperationalError) + self.assertIs(con.IntegrityError, drv.IntegrityError) + self.assertIs(con.InternalError, drv.InternalError) + self.assertIs(con.ProgrammingError, drv.ProgrammingError) + self.assertIs(con.NotSupportedError, drv.NotSupportedError) def test_commit(self): con = self._connect() @@ -200,10 +204,9 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) - cur2.execute("select name from %sbooze" % self.table_prefix) + cur1.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + cur2.execute(f"select name from {self.table_prefix}booze") booze = cur2.fetchall() self.assertEqual(len(booze), 1) self.assertEqual(len(booze[0]), 1) @@ -220,7 +223,7 @@ def test_description(self): cur.description, 'cursor.description should be none after executing a' ' statement that can return no rows (such as DDL)') - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') self.assertEqual( len(cur.description), 1, 'cursor.description describes too many columns') @@ -232,8 +235,8 @@ def test_description(self): 'cursor.description[x][0] must return column name') self.assertEqual( cur.description[0][1], self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1]) + 'cursor.description[x][1] must return column type.' + f' Got: {cur.description[0][1]!r}') # Make sure self.description gets reset self.executeDDL2(cur) @@ -253,14 +256,13 @@ def test_rowcount(self): cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' ' statements') - cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) + cur.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") self.assertIn( cur.rowcount, (-1, 1), 'cursor.rowcount should == number or rows inserted, or' ' set to -1 after executing an insert statement') - cur.execute("select name from %sbooze" % self.table_prefix) + cur.execute(f"select name from {self.table_prefix}booze") self.assertIn( cur.rowcount, (-1, 1), 'cursor.rowcount should == number of rows returned, or' @@ -325,47 +327,38 @@ def test_execute(self): def _paraminsert(self, cur): self.executeDDL2(cur) + table_prefix = self.table_prefix + insert = f"{self.insert} into {table_prefix}barflys values" cur.execute( - "%s into %sbarflys values ('Victoria Bitter'," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix)) + f"{insert} ('Victoria Bitter'," + " 'thi%s :may ca%(u)se? 
troub:1e')") self.assertIn(cur.rowcount, (-1, 1)) if self.driver.paramstyle == 'qmark': cur.execute( - "%s into %sbarflys values (?," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (?, 'thi%s :may ca%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'numeric': cur.execute( - "%s into %sbarflys values (:1," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (:1, 'thi%s :may ca%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'named': cur.execute( - "%s into %sbarflys values (:beer," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (:beer, 'thi%s :may ca%(u)se? troub:1e')", {'beer': "Cooper's"}) elif self.driver.paramstyle == 'format': cur.execute( - "%s into %sbarflys values (%%s," - " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (%s, 'thi%%s :may ca%%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'pyformat': cur.execute( - "%s into %sbarflys values (%%(beer)s," - " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (%(beer)s, 'thi%%s :may ca%%(u)se? troub:1e')", {'beer': "Cooper's"}) else: self.fail('Invalid paramstyle') self.assertIn(cur.rowcount, (-1, 1)) - cur.execute('select name, drink from %sbarflys' % self.table_prefix) + cur.execute(f'select name, drink from {table_prefix}barflys') res = cur.fetchall() self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') beers = [res[0][0], res[1][0]] @@ -382,48 +375,38 @@ def _paraminsert(self, cur): self.assertEqual( res[0][1], trouble, 'cursor.fetchall retrieved incorrect data, or data inserted' - ' incorrectly. Got=%s, Expected=%s' % ( - repr(res[0][1]), repr(trouble))) + f' incorrectly. Got: {res[0][1]!r}, Expected: {trouble!r}') self.assertEqual( res[1][1], trouble, 'cursor.fetchall retrieved incorrect data, or data inserted' - ' incorrectly. Got=%s, Expected=%s' % ( - repr(res[1][1]), repr(trouble))) + f' incorrectly. 
Got: {res[1][1]!r}, Expected: {trouble!r}') def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) + table_prefix = self.table_prefix + insert = f'{self.insert} into {table_prefix}booze values' largs = [("Cooper's",), ("Boag's",)] margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] if self.driver.paramstyle == 'qmark': - cur.executemany( - '%s into %sbooze values (?)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (?)', largs) elif self.driver.paramstyle == 'numeric': - cur.executemany( - '%s into %sbooze values (:1)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (:1)', largs) elif self.driver.paramstyle == 'named': - cur.executemany( - '%s into %sbooze values (:beer)' % ( - self.insert, self.table_prefix), margs) + cur.executemany(f'{insert} (:beer)', margs) elif self.driver.paramstyle == 'format': - cur.executemany( - '%s into %sbooze values (%%s)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (%s)', largs) elif self.driver.paramstyle == 'pyformat': - cur.executemany( - '%s into %sbooze values (%%(beer)s)' % ( - self.insert, self.table_prefix), margs) + cur.executemany(f'{insert} (%(beer)s)', margs) else: self.fail('Unknown paramstyle') self.assertIn( cur.rowcount, (-1, 2), 'insert using cursor.executemany set cursor.rowcount to' - ' incorrect value %r' % cur.rowcount) - cur.execute('select name from %sbooze' % self.table_prefix) + f' incorrect value {cur.rowcount!r}') + cur.execute(f'select name from {table_prefix}booze') res = cur.fetchall() self.assertEqual( len(res), 2, @@ -449,7 +432,7 @@ def test_fetchone(self): self.executeDDL1(cur) self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') self.assertIsNone( cur.fetchone(), 'cursor.fetchone should return None if a query retrieves' @@ -458,12 +441,12 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) + cur.execute( + f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchone() self.assertEqual( len(r), 1, @@ -490,8 +473,7 @@ def test_fetchone(self): def _populate(self): """Return a list of SQL commands to setup the DB for fetching tests.""" populate = [ - "%s into %sbooze values ('%s')" % ( - self.insert, self.table_prefix, s) + f"{self.insert} into {self.table_prefix}booze values ('{s}')" for s in self.samples] return populate @@ -508,7 +490,7 @@ def test_fetchmany(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchmany() self.assertEqual( len(r), 1, @@ -532,7 +514,7 @@ def test_fetchmany(self): # Same as above, using cursor.arraysize cur.arraysize = 4 - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchmany() # Should get 4 rows self.assertEqual( len(r), 4, @@ -544,7 +526,7 @@ def test_fetchmany(self): self.assertIn(cur.rowcount, (-1, 6)) cur.arraysize = 6 - cur.execute('select name from %sbooze' 
% self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows = cur.fetchmany() # Should get all rows self.assertIn(cur.rowcount, (-1, 6)) self.assertEqual(len(rows), 6) @@ -566,7 +548,7 @@ def test_fetchmany(self): self.assertIn(cur.rowcount, (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}barflys') r = cur.fetchmany() # Should get empty sequence self.assertEqual( len(r), 0, @@ -594,7 +576,7 @@ def test_fetchall(self): # after executing a a statement that cannot return rows self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, len(self.samples))) self.assertEqual( @@ -613,7 +595,7 @@ def test_fetchall(self): self.assertIn(cur.rowcount, (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}barflys') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, 0)) self.assertEqual( @@ -632,7 +614,7 @@ def test_mixedfetch(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows1 = cur.fetchone() rows23 = cur.fetchmany(2) rows4 = cur.fetchone() @@ -676,51 +658,45 @@ def help_nextset_setUp(self, cur): def help_nextset_tearDown(self, cur): """Clean up after nextset test. - If cleaning up is needed after nextSetTest. + If cleaning up is needed after test_nextset. """ raise NotImplementedError('Helper not implemented') # cur.execute("drop procedure deleteme") - # example test implementation only def test_nextset(self): - con = self._connect() - try: - cur = con.cursor() - if not hasattr(cur, 'nextset'): - return - - try: - self.executeDDL1(cur) - for sql in self._populate(): - cur.execute(sql) - - self.help_nextset_setUp(cur) - - cur.callproc('deleteme') - number_of_rows = cur.fetchone() - self.assertEqual(number_of_rows[0], len(self.samples)) - self.assertTrue(cur.nextset()) - names = cur.fetchall() - self.assertEqual(len(names), len(self.samples)) - s = cur.nextset() - self.assertIsNone(s, 'No more return sets, should return None') - finally: - self.help_nextset_tearDown(cur) - - finally: - con.close() - - # noinspection PyRedeclaration - def test_nextset(self): # noqa: F811 + """Test the nextset functionality.""" raise NotImplementedError('Drivers need to override this test') + # example test implementation only: + # con = self._connect() + # try: + # cur = con.cursor() + # if not hasattr(cur, 'nextset'): + # return + # try: + # self.executeDDL1(cur) + # for sql in self._populate(): + # cur.execute(sql) + # self.help_nextset_setUp(cur) + # cur.callproc('deleteme') + # number_of_rows = cur.fetchone() + # self.assertEqual(number_of_rows[0], len(self.samples)) + # self.assertTrue(cur.nextset()) + # names = cur.fetchall() + # self.assertEqual(len(names), len(self.samples)) + # self.assertIsNone( + # cur.nextset(), 'No more return sets, should return None') + # finally: + # self.help_nextset_tearDown(cur) + # finally: + # con.close() def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue( - hasattr(cur, 'arraysize'), 'cursor.arraysize must be defined') + self.assertTrue(hasattr(cur, 
'arraysize'), + 'cursor.arraysize must be defined') finally: con.close() @@ -756,9 +732,9 @@ def test_None(self): # inserting NULL to the second column, because some drivers might # need the first one to be primary key, which means it needs # to have a non-NULL value - cur.execute("%s into %sbarflys values ('a', NULL)" % ( - self.insert, self.table_prefix)) - cur.execute('select drink from %sbarflys' % self.table_prefix) + cur.execute(f"{self.insert} into {self.table_prefix}barflys" + " values ('a', NULL)") + cur.execute(f'select drink from {self.table_prefix}barflys') r = cur.fetchall() self.assertEqual(len(r), 1) self.assertEqual(len(r[0]), 1) diff --git a/tests/test_classic.py b/tests/test_classic.py index 799cb6c7..6319d5d5 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -48,11 +48,11 @@ def setUpClass(cls): except Exception: pass try: - db.query("DROP TABLE %s._test_schema" % (t,)) + db.query(f"DROP TABLE {t}._test_schema") except Exception: pass - db.query("CREATE TABLE %s._test_schema" - " (%s int PRIMARY KEY)" % (t, t)) + db.query(f"CREATE TABLE {t}._test_schema" + f" ({t} int PRIMARY KEY)") db.close() def setUp(self): @@ -60,7 +60,7 @@ def setUp(self): db = open_db() db.query("TRUNCATE TABLE _test_schema") for t in ('_test1', '_test2'): - db.query("TRUNCATE TABLE %s._test_schema" % t) + db.query(f"TRUNCATE TABLE {t}._test_schema") db.close() def test_invalid_name(self): diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c456b4ec..f7ca2a46 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -54,7 +54,7 @@ def testCanConnect(self): connection = connect() rc = connection.poll() except pg.Error as error: - self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.fail(f'Cannot connect to database {dbname}:\n{error}') self.assertEqual(rc, pg.POLLING_OK) self.assertIs(connection.is_non_blocking(), False) connection.set_non_blocking(True) @@ -74,7 +74,7 @@ def testCanConnectNoWait(self): while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): rc = connection.poll() except pg.Error as error: - self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.fail(f'Cannot connect to database {dbname}:\n{error}') self.assertEqual(rc, pg.POLLING_OK) self.assertIs(connection.is_non_blocking(), False) connection.set_non_blocking(True) @@ -310,7 +310,7 @@ def testMethodReset(self): encoding = query('show client_encoding').getresult()[0][0].upper() changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8' self.assertNotEqual(encoding, changed_encoding) - self.connection.query("set client_encoding=%s" % changed_encoding) + self.connection.query(f"set client_encoding={changed_encoding}") new_encoding = query('show client_encoding').getresult()[0][0].upper() self.assertEqual(new_encoding, changed_encoding) self.connection.reset() @@ -459,7 +459,7 @@ def testGetresultDecimal(self): def testGetresultString(self): result = 'Hello, world!' - q = "select '%s'" % result + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -503,7 +503,7 @@ def testDictresultDecimal(self): def testDictresultString(self): result = 'Hello, world!' 
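
Note: the hunks above and below all apply one mechanical pattern, replacing %-style string formatting with f-string literals. A minimal sketch of the rule, using illustrative names that do not appear in the patch (f-strings require Python 3.6 or newer):

    # old style: placeholder in the template, interpolation via %
    table_prefix = 'dbapi20test_'
    query = 'select name from %sbooze' % table_prefix
    # new style: the expression is embedded directly in the literal
    assert query == f'select name from {table_prefix}booze'
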
- q = "select '%s' as greeting" % result + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -699,7 +699,7 @@ def testFieldInfoName(self): for field_num, info in enumerate(result): field_name = info[0] if field_num > 0: - field_name = '"%s"' % field_name + field_name = f'"{field_name}"' r = f(field_name) self.assertIsInstance(r, tuple) self.assertEqual(len(r), 4) @@ -841,28 +841,27 @@ def tearDown(self): self.c.close() def testGetresulAscii(self): - result = u'Hello, world!' - q = u"select '%s'" % result + result = 'Hello, world!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def testDictresulAscii(self): - result = u'Hello, world!' - q = u"select '%s' as greeting" % result + result = 'Hello, world!' + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) def testGetresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s'" % result + result = 'Hello, wörld & мир!' + q = f"select '{result}'" # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -872,13 +871,12 @@ def testGetresultUtf8(self): self.assertEqual(v, result) def testDictresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s' as greeting" % result + result = 'Hello, wörld & мир!' + q = f"select '{result}' as greeting" try: v = self.c.query(q).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -891,8 +889,8 @@ def testGetresultLatin1(self): self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s'" % result + result = 'Hello, wörld!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -906,8 +904,8 @@ def testDictresultLatin1(self): self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s' as greeting" % result + result = 'Hello, wörld!' + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -921,8 +919,8 @@ def testGetresultCyrillic(self): self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s'" % result + result = 'Hello, мир!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -936,8 +934,8 @@ def testDictresultCyrillic(self): self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s' as greeting" % result + result = 'Hello, мир!' 
+ q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -951,8 +949,8 @@ def testGetresultLatin9(self): self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s'" % result + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -966,8 +964,8 @@ def testDictresultLatin9(self): self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s' as menu" % result + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + q = f"select '{result}' as menu" v = self.c.query(q).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1092,7 +1090,7 @@ def testQueryWithUnicodeParams(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertEqual( - query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult(), + query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult(), [('Hello, wörld!',)]) def testQueryWithUnicodeParamsLatin1(self): @@ -1103,22 +1101,22 @@ def testQueryWithUnicodeParamsLatin1(self): query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() + r = query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + ('Hello', 'мир')) query('set client_encoding=iso_8859_1') r = query( - "select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() + "select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + ('Hello', 'мир')) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) + ('Hello', 'wörld')) def testQueryWithUnicodeParamsCyrillic(self): query = self.c.query @@ -1130,14 +1128,14 @@ def testQueryWithUnicodeParamsCyrillic(self): self.skipTest("database does not support cyrillic") self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) + ('Hello', 'wörld')) r = query( - "select $1||', '||$2||'!'", ('Hello', u'мир')).getresult() + "select $1||', '||$2||'!'", ('Hello', 'мир')).getresult() self.assertEqual(r, [('Hello, мир!',)]) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир!')) + ('Hello', 'мир!')) def testQueryWithMixedParams(self): self.assertEqual( @@ -1264,7 +1262,7 @@ def tearDown(self): self.c.close() def assert_proper_cast(self, value, pgtype, pytype): - q = 'select $1::%s' % (pgtype,) + q = f'select $1::{pgtype}' try: r = self.c.query(q, (value,)).getresult()[0][0] except pg.ProgrammingError as e: @@ -1275,8 +1273,8 @@ def assert_proper_cast(self, value, pgtype, pytype): self.assertIsInstance(r, pytype) if isinstance(value, str): if 
not value or ' ' in value or '{' in value: - value = '"%s"' % value - value = '{%s}' % value + value = f'"{value}"' + value = f'{{{value}}}' r = self.c.query(q + '[]', (value,)).getresult()[0][0] if pgtype.startswith(('date', 'time', 'interval')): # arrays of these are casted by the DB wrapper only @@ -2009,11 +2007,11 @@ def testInserttableByteValues(self): except pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") row_bytes = tuple( s.encode('utf-8') if isinstance(s, str) else s for s in row_unicode) @@ -2028,11 +2026,11 @@ def testInserttableUnicodeUtf8(self): except pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -2044,16 +2042,16 @@ def testInserttableUnicodeLatin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) row_unicode = tuple( - s.replace(u'€', u'¥') if isinstance(s, str) else s + s.replace('€', '¥') if isinstance(s, str) else s for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) @@ -2067,11 +2065,11 @@ def testInserttableUnicodeLatin9(self): self.skipTest("database does not support latin9") return # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back('latin9'), data) @@ -2079,11 +2077,11 @@ def testInserttableUnicodeLatin9(self): def testInserttableNoEncoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") 
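
Note: the latin1 variant above expects a UnicodeEncodeError for the euro sign and swaps it for '¥', while the latin9 variant keeps it. That difference is a plain codec fact, reproducible without any database:

    # ISO-8859-1 (latin1) predates the euro sign and cannot encode it;
    # ISO-8859-15 (latin9) replaces the currency sign at 0xA4 with it.
    try:
        '€'.encode('latin1')
    except UnicodeEncodeError:
        pass  # expected: latin1 has no code point for €
    assert '€'.encode('iso8859-15') == b'\xa4'
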
+ 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) @@ -2174,7 +2172,7 @@ def testPutline(self): query("copy test from stdin") try: for i, v in data: - putline("%d\t%s\n" % (i, v)) + putline(f"{i}\t{v}\n") finally: self.c.endcopy() r = query("select * from test").getresult() @@ -2189,7 +2187,7 @@ def testPutlineBytesAndUnicode(self): self.skipTest('database does not support utf8') query("copy test from stdin") try: - putline(u"47\tkäse\n".encode('utf8')) + putline("47\tkäse\n".encode('utf8')) putline("35\twürstel\n") finally: self.c.endcopy() @@ -2208,7 +2206,7 @@ def testGetline(self): v = getline() if i < n: # noinspection PyStringFormat - self.assertEqual(v, '%d\t%s' % data[i]) + self.assertEqual(v, '{}\t{}'.format(*data[i])) elif i == n: self.assertIsNone(v) finally: @@ -2224,7 +2222,7 @@ def testGetlineBytesAndUnicode(self): query("select 'käse+würstel'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - data = [(54, u'käse'.encode('utf8')), (73, u'würstel')] + data = [(54, 'käse'.encode('utf8')), (73, 'würstel')] self.c.inserttable('test', data) query("copy test to stdout") try: @@ -2405,7 +2403,7 @@ def testSetDecimalPoint(self): # first try with English localization (using the point) for lc in en_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2456,7 +2454,7 @@ def testSetDecimalPoint(self): # then try with German localization (using the comma) for lc in de_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2714,15 +2712,15 @@ def testEscapeString(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) + self.assertEqual(r, 'plain') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is'' käse".encode('utf-8')) - r = f(u"that's cheesy") + self.assertEqual(r, "das is'' käse".encode('utf-8')) + r = f("that's cheesy") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s cheesy") + self.assertEqual(r, "that''s cheesy") r = f(r"It's bad to have a \ inside.") self.assertEqual(r, r"It''s bad to have a \\ inside.") @@ -2732,15 +2730,15 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) + self.assertEqual(r, 'plain') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") - r = f(u"that's cheesy") + r = f("that's cheesy") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s cheesy") + self.assertEqual(r, "that''s cheesy") r = f(b'O\x00ps\xff!') self.assertEqual(r, b'O\\\\000ps\\\\377!') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3d372ad3..79c962a4 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -466,15 +466,15 @@ def createTable(self, table, definition, temporary=True, oids=None, values=None): query = self.db.query 
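
Note: the escape tests a few hunks above pin down the classic escaping rules that these strings rely on: single quotes are doubled, and bytea escaping octal-encodes unprintable bytes. A minimal check against the module-level functions, assuming the C extension is built and importable as pg:

    import pg

    # single quotes are doubled so the value can sit inside '...'
    assert pg.escape_string("that's cheesy") == "that''s cheesy"
    # unprintable bytes come back octal-escaped for use in a literal
    print(pg.escape_bytea(b'O\x00ps\xff!'))
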
if '"' not in table or '.' in table: - table = '"%s"' % table + table = f'"{table}"' if not temporary: - q = 'drop table if exists %s cascade' % table + q = f'drop table if exists {table} cascade' query(q) self.addCleanup(query, q) temporary = 'temporary table' if temporary else 'table' as_query = definition.startswith(('as ', 'AS ')) if not as_query and not definition.startswith('('): - definition = '(%s)' % definition + definition = f'({definition})' with_oids = 'with oids' if oids else ( 'without oids' if self.oids else '') q = ['create', temporary, table] @@ -488,8 +488,8 @@ def createTable(self, table, definition, for params in values: if not isinstance(params, (list, tuple)): params = [params] - values = ', '.join('$%d' % (n + 1) for n in range(len(params))) - q = "insert into %s values (%s)" % (table, values) + values = ', '.join(f'${n + 1}' for n in range(len(params))) + q = f"insert into {table} values ({values})" query(q, params) def testClassName(self): @@ -504,15 +504,15 @@ def testEscapeLiteral(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u"'plain'") - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, "'plain'") + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"'that''s käse'".encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, "'that''s käse'".encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u"'that''s käse'") + self.assertEqual(r, "'that''s käse'") self.assertEqual(f(r"It's fine to have a \ inside."), r" E'It''s fine to have a \\ inside.'") self.assertEqual(f('No "quotes" must be escaped.'), @@ -523,15 +523,15 @@ def testEscapeIdentifier(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b'"plain"') - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u'"plain"') - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, '"plain"') + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u'"that\'s käse"'.encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, '"that\'s käse"'.encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u'"that\'s käse"') + self.assertEqual(r, '"that\'s käse"') self.assertEqual(f(r"It's fine to have a \ inside."), '"It\'s fine to have a \\ inside."') self.assertEqual(f('All "quotes" must be escaped.'), @@ -542,15 +542,15 @@ def testEscapeString(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"plain") - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u"plain") - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, "plain") + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"that''s käse".encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, "that''s käse".encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s käse") + self.assertEqual(r, "that''s käse") self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") @@ -561,15 +561,15 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x706c61696e') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'\\x706c61696e') - r = f(u"das is' 
käse".encode('utf-8')) + self.assertEqual(r, '\\x706c61696e') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') - r = f(u"das is' käse") + r = f("das is' käse") self.assertIsInstance(r, str) - self.assertEqual(r, u'\\x64617320697327206bc3a47365') + self.assertEqual(r, '\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') def testUnescapeBytea(self): @@ -577,15 +577,15 @@ def testUnescapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode('utf8')) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode('utf8')) self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!') self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e') self.assertEqual(f(r'\\x746861742773206be47365'), @@ -848,7 +848,7 @@ def testCreateTable(self): values = [(2, "World!"), (1, "Hello")] self.createTable(table, "n smallint, t varchar", temporary=True, oids=False, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") @@ -859,10 +859,10 @@ def testCreateTableWithOids(self): values = [(2, "World!"), (1, "Hello")] self.createTable(table, "n smallint, t varchar", temporary=True, oids=True, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - r = self.db.query('select oid from "%s" limit 1' % table).getresult() + r = self.db.query(f'select oid from "{table}" limit 1').getresult() self.assertIsInstance(r[0][0], int) def testQuery(self): @@ -1131,56 +1131,56 @@ def testPkey(self): pkey = self.db.pkey self.assertRaises(KeyError, pkey, 'test') for t in ('pkeytest', 'primary key test'): - self.createTable('%s0' % t, 'a smallint') - self.createTable('%s1' % t, 'b smallint primary key') - self.createTable('%s2' % t, 'c smallint, d smallint primary key') + self.createTable(f'{t}0', 'a smallint') + self.createTable(f'{t}1', 'b smallint primary key') + self.createTable(f'{t}2', 'c smallint, d smallint primary key') self.createTable( - '%s3' % t, + f'{t}3', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') self.createTable( - '%s4' % t, + f'{t}4', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') self.createTable( - '%s5' % t, 'more_than_one_letter varchar primary key') + f'{t}5', 'more_than_one_letter varchar primary key') self.createTable( - '%s6' % t, '"with space" date primary key') + f'{t}6', '"with space" date primary key') self.createTable( - '%s7' % t, + f'{t}7', 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') - self.assertRaises(KeyError, pkey, '%s0' % t) - self.assertEqual(pkey('%s1' % t), 'b') - self.assertEqual(pkey('%s1' % t, True), ('b',)) - self.assertEqual(pkey('%s1' % t, 
composite=False), 'b') - self.assertEqual(pkey('%s1' % t, composite=True), ('b',)) - self.assertEqual(pkey('%s2' % t), 'd') - self.assertEqual(pkey('%s2' % t, composite=True), ('d',)) - r = pkey('%s3' % t) + self.assertRaises(KeyError, pkey, f'{t}0') + self.assertEqual(pkey(f'{t}1'), 'b') + self.assertEqual(pkey(f'{t}1', True), ('b',)) + self.assertEqual(pkey(f'{t}1', composite=False), 'b') + self.assertEqual(pkey(f'{t}1', composite=True), ('b',)) + self.assertEqual(pkey(f'{t}2'), 'd') + self.assertEqual(pkey(f'{t}2', composite=True), ('d',)) + r = pkey(f'{t}3') self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s3' % t, composite=False) + r = pkey(f'{t}3', composite=False) self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s4' % t) + r = pkey(f'{t}4') self.assertIsInstance(r, tuple) self.assertEqual(r, ('h', 'f')) - self.assertEqual(pkey('%s5' % t), 'more_than_one_letter') - self.assertEqual(pkey('%s6' % t), 'with space') - r = pkey('%s7' % t) + self.assertEqual(pkey(f'{t}5'), 'more_than_one_letter') + self.assertEqual(pkey(f'{t}6'), 'with space') + r = pkey(f'{t}7') self.assertIsInstance(r, tuple) self.assertEqual(r, ( 'a_very_long_column_name', 'with space', '42')) # a newly added primary key will be detected - query('alter table "%s0" add primary key (a)' % t) - self.assertEqual(pkey('%s0' % t), 'a') + query(f'alter table "{t}0" add primary key (a)') + self.assertEqual(pkey(f'{t}0'), 'a') # a changed primary key will not be detected, # indicating that the internal cache is operating - query('alter table "%s1" rename column b to x' % t) - self.assertEqual(pkey('%s1' % t), 'b') + query(f'alter table "{t}1" rename column b to x') + self.assertEqual(pkey(f'{t}1'), 'b') # we get the changed primary key when the cache is flushed - self.assertEqual(pkey('%s1' % t, flush=True), 'x') + self.assertEqual(pkey(f'{t}1', flush=True), 'x') def testGetDatabases(self): databases = self.db.get_databases() @@ -1197,7 +1197,7 @@ def testGetTables(self): 'averyveryveryveryveryveryveryreallyreallylongtablename', 'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z') for t in tables: - self.db.query('drop table if exists "%s" cascade' % t) + self.db.query(f'drop table if exists "{t}" cascade') before_tables = get_tables() self.assertIsInstance(before_tables, list) for t in before_tables: @@ -1212,8 +1212,8 @@ def testGetTables(self): self.createTable(t, 'as select 0', temporary=False) current_tables = get_tables() new_tables = [t for t in current_tables if t not in before_tables] - expected_new_tables = ['public.%s' % ( - '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables] + expected_new_tables = ['public.' 
+ ( + f'"{t}"' if ' ' in t or t != t.lower() else t) for t in tables] self.assertEqual(new_tables, expected_new_tables) self.doCleanups() after_tables = get_tables() @@ -1513,8 +1513,8 @@ def testGetGeneratedIsCached(self): table = 'test_get_generated_2' self.createTable(table, 'i int primary key') self.assertFalse(get_generated(table)) - query('alter table %s alter column i' - ' add generated always as identity' % table) + query(f'alter table {table} alter column i' + ' add generated always as identity') self.assertFalse(get_generated(table)) self.assertEqual(get_generated(table, flush=True), {'i'}) @@ -1573,8 +1573,8 @@ def testGet(self): r = get(table, s, ('n', 't')) self.assertIs(r, s) self.assertEqual(r, dict(n=1, t='x')) - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') r = get(table, 2) self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) @@ -1605,7 +1605,7 @@ def testGetWithOids(self): self.assertRaises(pg.ProgrammingError, get, table, 2) self.assertRaises(KeyError, get, table, {}, 'oid') r = get(table, 2, 'n') - qoid = 'oid(%s)' % table + qoid = f'oid({table})' self.assertIn(qoid, r) oid = r[qoid] self.assertIsInstance(oid, int) @@ -1632,8 +1632,8 @@ def testGetWithOids(self): self.assertEqual(get(table, r, 'n')['t'], 'z') self.assertEqual(get(table, 1, 'n')['t'], 'x') self.assertEqual(get(table, r, 'oid')['t'], 'z') - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') self.assertEqual(get(table, 3)['t'], 'z') self.assertEqual(get(table, 1)['t'], 'x') self.assertEqual(get(table, 2)['t'], 'y') @@ -1836,10 +1836,10 @@ def testInsert(self): ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S') expect['ts'] = ts self.assertEqual(data, expect) - data = query('select * from "%s"' % table).dictresult()[0] + data = query(f'select * from "{table}"').dictresult()[0] data = dict(item for item in data.items() if item[0] in expect) self.assertEqual(data, expect) - query('truncate table "%s"' % table) + query(f'truncate table "{table}"') def testInsertWithOids(self): if not self.oids: @@ -1923,7 +1923,7 @@ def testInsertWithQuotedNames(self): self.assertEqual(r['Prime!'], 11) self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 11) @@ -1995,7 +1995,7 @@ def testUpdate(self): r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') @@ -2091,7 +2091,7 @@ def testUpdateWithoutOid(self): r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') @@ -2107,20 +2107,20 @@ def testUpdateWithCompositeKey(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'd') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') s.update(dict(n=4, t='e')) r = 
update(table, s) self.assertEqual(r['n'], 4) self.assertEqual(r['t'], 'e') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') - q = 'select t from "%s" where n=4' % table + q = f'select t from "{table}" where n=4' r = query(q).getresult() self.assertEqual(len(r), 0) - query('drop table "%s"' % table) + query(f'drop table "{table}"') table = 'update_test_table_2' self.createTable(table, 'n integer, m integer, t text, primary key (n, m)', @@ -2129,7 +2129,7 @@ def testUpdateWithCompositeKey(self): self.assertRaises(KeyError, update, table, dict(n=2, t='b')) self.assertEqual(update(table, dict(n=2, m=2, t='x'))['t'], 'x') - q = 'select t from "%s" where n=2 order by m' % table + q = f'select t from "{table}" where n=2 order by m' r = [r[0] for r in query(q).getresult()] self.assertEqual(r, ['c', 'x']) @@ -2146,7 +2146,7 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['Prime!'], 13) self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 13) @@ -2173,7 +2173,7 @@ def testUpdateWithGeneratedColumns(self): self.createTable(table, table_def) i, d = 35, 1001 j = i + 7 - r = query('insert into %s (i, d) values (%d, %d)' % (table, i, d)) + r = query(f'insert into {table} (i, d) values ({i}, {d})') self.assertEqual(r, '1') r = get(table, d) self.assertIsInstance(r, dict) @@ -2202,7 +2202,7 @@ def testUpsert(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'y') - q = 'select n, t from "%s" order by n limit 3' % table + q = f'select n, t from "{table}" order by n limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 'x'), (2, 'y')]) s.update(t='z') @@ -2357,7 +2357,7 @@ def testUpsertWithCompositeKey(self): self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 3) self.assertEqual(r['t'], 'y') - q = 'select n, m, t from "%s" order by n, m limit 3' % table + q = f'select n, m, t from "{table}" order by n, m limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')]) s.update(t='z') @@ -2413,7 +2413,7 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) self.assertEqual(r['Questions?'], 'Yes.') - q = 'select * from "%s" limit 2' % table + q = f'select * from "{table}" limit 2' r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'Yes.')]) s.update({'Questions?': 'No.'}) @@ -2506,7 +2506,7 @@ def testDelete(self): self.assertEqual(s, 1) s = delete(table, r) self.assertEqual(s, 0) - r = query('select * from "%s"' % table).dictresult() + r = query(f'select * from "{table}"').dictresult() self.assertEqual(len(r), 1) r = r[0] result = {'n': 2, 't': 'y'} @@ -2574,7 +2574,7 @@ def testDeleteWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) for i in range(5): - query("insert into test_table values (%d, %d)" % (i + 1, i + 2)) + query(f"insert into test_table values ({i + 1}, {i + 2})") s = dict(m=2) self.assertRaises(KeyError, delete, 'test_table', s) s = dict(m=2, oid=oid) @@ -2625,10 +2625,10 @@ def testDeleteWithCompositeKey(self): values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) - 
r = query('select t from "%s" where n=2' % table).getresult() + r = query(f'select t from "{table}" where n=2').getresult() self.assertEqual(r, []) self.assertEqual(self.db.delete(table, dict(n=2)), 0) - r = query('select t from "%s" where n=3' % table).getresult()[0][0] + r = query(f'select t from "{table}" where n=3').getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' self.createTable( @@ -2637,16 +2637,16 @@ def testDeleteWithCompositeKey(self): for n in range(3) for m in range(2)]) self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1) - r = [r[0] for r in query('select t from "%s" where n=2' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=2' + ' order by m').getresult()] self.assertEqual(r, ['c']) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + ' order by m').getresult()] self.assertEqual(r, ['e', 'f']) self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + f' order by m').getresult()] self.assertEqual(r, ['f']) def testDeleteWithQuotedNames(self): @@ -2660,12 +2660,12 @@ def testDeleteWithQuotedNames(self): r = {'Prime!': 17} r = delete(table, r) self.assertEqual(r, 0) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 1) r = {'Prime!': 19} r = delete(table, r) self.assertEqual(r, 1) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 0) def testDeleteReferenced(self): @@ -2718,7 +2718,7 @@ def testTempCrud(self): r = self.db.get(table, 2) self.assertEqual(r['t'], 'two') self.db.delete(table, r) - r = self.db.query('select n, t from %s order by 1' % table).getresult() + r = self.db.query(f'select n, t from {table} order by 1').getresult() self.assertEqual(r, [(1, 'one'), (3, 'three')]) def testTruncate(self): @@ -2798,16 +2798,16 @@ def testTruncateCascade(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_parent', cascade=True) r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_child') @@ -2859,8 +2859,8 @@ def testTruncateOnly(self): self.createTable('test_child_2', 'm smallint) inherits (test_parent_2') for t in '', '_2': for n in range(3): - query("insert into test_parent%s (n) values (1)" % t) - query("insert into test_child%s (n, m) values (2, 3)" % t) + query(f"insert into test_parent{t} (n) values (1)") + query(f"insert into test_child{t} (n, m) values (2, 3)") q = ("select (select count(*) from 
test_parent)," " (select count(*) from test_child)," " (select count(*) from test_parent_2)," @@ -2883,17 +2883,17 @@ def testTruncateQuoted(self): query = self.db.query table = "test table for truncate()" self.createTable(table, 'n smallint', temporary=False, values=[1] * 3) - q = 'select count(*) from "%s"' % table + q = f'select count(*) from "{table}"' r = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate(table) r = query(q).getresult()[0][0] self.assertEqual(r, 0) for i in range(3): - query('insert into "%s" values (1)' % table) + query(f'insert into "{table}" values (1)') r = query(q).getresult()[0][0] self.assertEqual(r, 3) - truncate('public."%s"' % table) + truncate(f'public."{table}"') r = query(q).getresult()[0][0] self.assertEqual(r, 0) @@ -2975,10 +2975,10 @@ def testGetAsList(self): r = get_as_list(table, what='name', limit=1, scalar=True) self.assertIsInstance(r, list) self.assertEqual(r, expected[:1]) - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) names.insert(1, (1, 'Snowball')) - query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball')) + query(f'insert into "{table}" values ($1, $2)', (1, 'Snowball')) r = get_as_list(table) self.assertIsInstance(r, list) self.assertEqual(r, names) @@ -2990,7 +2990,7 @@ def testGetAsList(self): self.assertIsInstance(r, list) self.assertEqual(set(r), set(names)) # test with arbitrary from clause - from_table = '(select lower(name) as n2 from "%s") as t2' % table + from_table = f'(select lower(name) as n2 from "{table}") as t2' r = get_as_list(from_table) self.assertIsInstance(r, list) r = {row[0] for row in r} @@ -3157,7 +3157,7 @@ def testGetAsDict(self): self.assertEqual(r, expected) self.assertNotIsInstance(self, OrderedDict) # test with arbitrary from clause - from_table = '(select id, lower(name) as n2 from "%s") as t2' % table + from_table = f'(select id, lower(name) as n2 from "{table}") as t2' # primary key must be passed explicitly in this case self.assertRaises(pg.ProgrammingError, get_as_dict, from_table) r = get_as_dict(from_table, 'id') @@ -3165,7 +3165,7 @@ def testGetAsDict(self): expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors) self.assertEqual(r, expected) # test without a primary key - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) self.assertRaises(pg.ProgrammingError, get_as_dict, table) r = get_as_dict(table, keyname='id') @@ -3173,7 +3173,7 @@ def testGetAsDict(self): self.assertIsInstance(r, dict) self.assertEqual(r, expected) r = (1, '#007fff', 'Azure') - query('insert into "%s" values ($1, $2, $3)' % table, r) + query(f'insert into "{table}" values ($1, $2, $3)', r) # the last entry will win expected[1] = r[1:] r = get_as_dict(table, keyname='id') @@ -3971,7 +3971,7 @@ def testTimetz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = '%+03d00' % timezones[timezone] + tz = f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) d = time(15, 9, 26, tzinfo=tzinfo) @@ -4023,7 +4023,7 @@ def testTimestamptz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = '%+03d00' % timezones[timezone] + tz = 
f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', @@ -4174,7 +4174,7 @@ def testDbTypesTypecast(self): self.assertIs(dbtypes.get_typecast('int4'), int) self.assertNotIn('circle', dbtypes) self.assertIsNone(dbtypes.get_typecast('circle')) - squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 dbtypes.set_typecast('circle', squared_circle) self.assertIs(dbtypes.get_typecast('circle'), squared_circle) r = self.db.query("select '0,0,1'::circle").getresult()[0][0] @@ -4199,7 +4199,7 @@ def testGetSetTypeCast(self): self.assertIs(get_typecast('bool'), pg.cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) - squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 self.assertNotIn('circle', dbtypes) set_typecast('circle', squared_circle) self.assertNotIn('circle', dbtypes) @@ -4698,23 +4698,23 @@ def setUpClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema if exists %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema if exists {schema} cascade") try: - query("create schema %s" % (schema,)) + query(f"create schema {schema}") except pg.ProgrammingError: raise RuntimeError( "The test user cannot create schemas.\n" - "Grant create on database %s to the user" - " for running these tests." % dbname) + f"Grant create on database {dbname} to the user" + " for running these tests.") else: schema = "public" - query("drop table if exists %s.t" % (schema,)) - query("drop table if exists %s.t%d" % (schema, num_schema)) - query("create table %s.t %s as select 1 as n, %d as d" - % (schema, cls.with_oids, num_schema)) - query("create table %s.t%d %s as select 1 as n, %d as d" - % (schema, num_schema, cls.with_oids, num_schema)) + query(f"drop table if exists {schema}.t") + query(f"drop table if exists {schema}.t{num_schema}") + query(f"create table {schema}.t {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") + query(f"create table {schema}.t{num_schema} {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") db.close() cls.cls_set_up = True @@ -4724,12 +4724,12 @@ def tearDownClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema {schema} cascade") else: schema = "public" - query("drop table %s.t" % (schema,)) - query("drop table %s.t%d" % (schema, num_schema)) + query(f"drop table {schema}.t") + query(f"drop table {schema}.t{num_schema}") db.close() def setUp(self): @@ -4763,7 +4763,7 @@ def testGetAttnames(self): self.assertEqual(r, result) query("drop table if exists s3.t3m") self.addCleanup(query, "drop table s3.t3m") - query("create table s3.t3m %s as select 1 as m" % (self.with_oids,)) + query(f"create table s3.t3m {self.with_oids} as select 1 as m") result_m = {'m': 'int'} if self.with_oids: result_m['oid'] = 'int' @@ -4824,7 +4824,7 @@ def testQueryInformationSchema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array - q = "select array_agg(%s) from information_schema.columns" % q + q = f"select array_agg({q}) from information_schema.columns" q += " where 
table_schema in ('s1', 's2', 's3', 's4')" r = self.db.query(q).onescalar() self.assertIsInstance(r, list) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 282ec6df..adddc8ce 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -278,7 +278,7 @@ def testParserNested(self): def testParserTooDeeplyNested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '%sa,b,c%s' % ('{' * n, '}' * n) + r = '{' * n + 'a,b,c' + '}' * n if n > 16: # hard coded maximum depth self.assertRaises(ValueError, f, r) else: @@ -302,7 +302,7 @@ def testParserCast(self): self.assertEqual(f('{a}', str), ['a']) def cast(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('{a}', cast), ['a is ok']) def testParserDelim(self): @@ -528,7 +528,8 @@ def testParserNested(self): def testParserManyElements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '(%s)' % ','.join(map(str, range(n))) + r = ','.join(map(str, range(n))) + r = f'({r})' r = f(r, int) self.assertEqual(r, tuple(range(n))) @@ -544,7 +545,7 @@ def testParserCastUniform(self): self.assertEqual(f('(a)', str), ('a',)) def cast(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('(a)', cast), ('a is ok',)) def testParserCastNonUniform(self): @@ -571,11 +572,11 @@ def testParserCastNonUniform(self): (1, 'a', 2, 'b', 3, 'c')) def cast1(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('(a)', [cast1]), ('a is ok',)) def cast2(s): - return 'and %s is ok, too' % s + return f'and {s} is ok, too' self.assertEqual( f('(a,b)', [cast1, cast2]), ('a is ok', 'and b is ok, too')) self.assertRaises(ValueError, f, '(a)', [cast1, cast2]) @@ -870,9 +871,9 @@ def testEscapeString(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') + self.assertEqual(r, 'plain') r = f("that's cheese") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") @@ -882,9 +883,9 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') + self.assertEqual(r, 'plain') r = f("that's cheese") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") @@ -894,18 +895,18 @@ def testUnescapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode('utf-8')) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode('utf-8')) r = f(b'O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') - r = f(u'O\\000ps\\377!') + r = f('O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 3271686c..bdf3a613 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -38,7 +38,7 @@ def testLargeObjectIntConstants(self): try: value = getattr(pg, name) except AttributeError: - self.fail('Module constant %s is missing' % name) + self.fail(f'Module 
constant {name} is missing') self.assertIsInstance(value, int) @@ -187,10 +187,10 @@ def testStr(self): self.obj.write(data) oid = self.obj.oid r = str(self.obj) - self.assertEqual(r, 'Opened large object, oid %d' % oid) + self.assertEqual(r, f'Opened large object, oid {oid}') self.obj.close() r = str(self.obj) - self.assertEqual(r, 'Closed large object, oid %d' % oid) + self.assertEqual(r, f'Closed large object, oid {oid}') def testRepr(self): r = repr(self.obj) @@ -260,22 +260,22 @@ def testWrite(self): def testWriteLatin1Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write(u'käse'.encode('latin1')) + self.obj.write('käse'.encode('latin1')) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('latin1'), u'käse') + self.assertEqual(r.decode('latin1'), 'käse') def testWriteUtf8Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write(u'käse'.encode('utf8')) + self.obj.write('käse'.encode('utf8')) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), u'käse') + self.assertEqual(r.decode('utf8'), 'käse') def testWriteUtf8String(self): read = self.obj.read @@ -285,7 +285,7 @@ def testWriteUtf8String(self): self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), u'käse') + self.assertEqual(r.decode('utf8'), 'käse') def testSeek(self): seek = self.obj.seek @@ -367,7 +367,7 @@ def testUnlinkInexistent(self): unlink = self.obj.unlink self.obj.open(pg.INV_WRITE) self.obj.close() - self.pgcnx.query('select lo_unlink(%d)' % self.obj.oid) + self.pgcnx.query(f'select lo_unlink({self.obj.oid})') self.assertRaises(IOError, unlink) def testSize(self): @@ -446,7 +446,7 @@ def testExportInExistent(self): f = tempfile.NamedTemporaryFile() self.obj.open(pg.INV_WRITE) self.obj.close() - self.pgcnx.query('select lo_unlink(%d)' % self.obj.oid) + self.pgcnx.query(f'select lo_unlink({self.obj.oid})') self.assertRaises(IOError, export, f.name) f.close() diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 6f94cebd..dcc06382 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -88,7 +88,7 @@ def get_handler(self, event=None, arg_dict=None, stop_event=None): handler = self.db.notification_handler( event, callback, arg_dict, 0, stop_event) self.assertEqual(handler.event, event) - self.assertEqual(handler.stop_event, stop_event or 'stop_%s' % event) + self.assertEqual(handler.stop_event, stop_event or f'stop_{event}') self.assertIs(handler.callback, callback) if arg_dict is None: self.assertEqual(handler.arg_dict, {}) @@ -224,7 +224,7 @@ def start_handler(self, event=None, arg_dict=None, self.handler = handler self.assertIsInstance(handler, pg.NotificationHandler) self.assertEqual(handler.event, event) - self.assertEqual(handler.stop_event, stop_event or 'stop_%s' % event) + self.assertEqual(handler.stop_event, stop_event or f'stop_{event}') self.event = handler.event self.assertIs(handler.callback, callback) if arg_dict is None: @@ -277,9 +277,9 @@ def notify_query(self, stop=False, payload=None): if stop: event = self.handler.stop_event self.stopped = True - q = 'notify "%s"' % event + q = f'notify "{event}"' if payload: - q += ", '%s'" % payload + q += f", '{payload}'" arg_dict = self.arg_dict.copy() arg_dict.update(event=event, pid=1, extra=payload or '') self.db.query(q) @@ -370,14 +370,14 @@ def 
testNotifyQuotedNames(self): def testNotifyWithFivePayloads(self): self.start_handler('gimme_5', {'test': 'Gimme 5'}) for count in range(5): - self.notify_query(payload="Round %d" % count) + self.notify_query(payload=f"Round {count}") self.assertEqual(len(self.sent), 5) self.receive(stop=True) def testReceiveImmediately(self): self.start_handler('immediate', {'test': 'immediate'}) for count in range(3): - self.notify_query(payload="Round %d" % count) + self.notify_query(payload=f"Round {count}") self.receive() self.receive(stop=True) @@ -385,7 +385,7 @@ def testNotifyDistinctInTransaction(self): self.start_handler('test_transaction', {'transaction': True}) self.db.begin() for count in range(3): - self.notify_query(payload='Round %d' % count) + self.notify_query(payload=f'Round {count}') self.db.commit() self.receive(stop=True) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 01a89247..6062a4fa 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -32,7 +32,7 @@ class test_PyGreSQL(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () connect_kw_args = { - 'database': dbname, 'host': '%s:%d' % (dbhost or '', dbport or -1), + 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", 'user': dbuser, 'password': dbpasswd} lower_func = 'lower' # For stored procedure test @@ -164,7 +164,7 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): def row_factory(self, row): - return {'column %s' % desc[0]: value + return {f'column {desc[0]}': value for desc, value in zip(self.description, row)} con = self._connect() @@ -306,7 +306,7 @@ def test_description_fields(self): self.assertIsInstance(d, tuple) self.assertEqual(len(d), 7) self.assertIsInstance(d.name, str) - self.assertEqual(d.name, 'col%d' % i) + self.assertEqual(d.name, f'col{i}') self.assertIsInstance(d.type_code, str) self.assertEqual(d.type_code, c[0]) self.assertIsNone(d.display_size) @@ -382,7 +382,7 @@ def test_type_cache_typecast(self): cur = con.cursor() type_cache = con.type_cache self.assertIs(type_cache.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v # noqa: E731 + cast_int = lambda v: f'int({v})' # noqa: E731 type_cache.set_typecast('int4', cast_int) query = 'select 2::int2, 4::int4, 8::int8' cur.execute(query) @@ -454,7 +454,7 @@ def test_fetch_2_rows(self): cur = con.cursor() cur.execute("set datestyle to iso") cur.execute( - "create table %s (" + f"create table {table} (" "stringtest varchar," "binarytest bytea," "booltest bool," @@ -467,16 +467,16 @@ def test_fetch_2_rows(self): "timetest time," "datetimetest timestamp," "intervaltest interval," - "rowidtest oid)" % table) + "rowidtest oid)") cur.execute("set standard_conforming_strings to on") for s in ('numeric', 'monetary', 'time'): - cur.execute("set lc_%s to 'C'" % s) + cur.execute(f"set lc_{s} to 'C'") for _i in range(2): cur.execute( - "insert into %s values (" - "%%s,%%s,%%s,%%s,%%s,%%s,%%s," - "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values) - cur.execute("select * from %s" % table) + f"insert into {table} values (" + "%s,%s,%s,%s,%s,%s,%s," + "'%s'::money,%s,%s,%s,%s,%s)", values) + cur.execute(f"select * from {table}") rows = cur.fetchall() self.assertEqual(len(rows), 2) row0 = rows[0] @@ -503,12 +503,12 @@ def test_integrity_error(self): try: cur = con.cursor() cur.execute("set client_min_messages = warning") - cur.execute("create table %s (i int primary key)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) + cur.execute(f"create table 
{table} (i int primary key)") + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") self.assertRaises( pgdb.IntegrityError, cur.execute, - "insert into %s values (1)" % table) + f"insert into {table} values (1)") finally: con.close() @@ -517,11 +517,11 @@ def test_update_rowcount(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i int)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("update %s set i=2 where i=2 returning i" % table) + cur.execute(f"create table {table} (i int)") + cur.execute(f"insert into {table} values (1)") + cur.execute(f"update {table} set i=2 where i=2 returning i") self.assertEqual(cur.rowcount, 0) - cur.execute("update %s set i=2 where i=1 returning i" % table) + cur.execute(f"update {table} set i=2 where i=1 returning i") self.assertEqual(cur.rowcount, 1) cur.close() # keep rowcount even if cursor is closed (needed by SQLAlchemy) @@ -552,10 +552,10 @@ def test_float(self): try: cur = con.cursor() cur.execute( - "create table %s (n smallint, floattest float)" % table) + f"create table {table} (n smallint, floattest float)") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select floattest from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select floattest from {table} order by n") rows = cur.fetchall() self.assertEqual(cur.description[0].type_code, pgdb.FLOAT) self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY) @@ -589,9 +589,9 @@ def test_datetime(self): try: cur = con.cursor() cur.execute("set timezone = UTC") - cur.execute("create table %s (" + cur.execute(f"create table {table} (" "d date, t time, ts timestamp," - "tz timetz, tsz timestamptz)" % table) + "tz timetz, tsz timestamptz)") for n in range(3): values = [dt.date(), dt.time(), dt, dt.time(), dt] values[3] = values[3].replace(tzinfo=timezone.utc) @@ -609,16 +609,16 @@ def test_datetime(self): pgdb.Timestamp(*(d + t + z))] for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', 'sql, mdy', 'sql, dmy', 'german'): - cur.execute("set datestyle to %s" % datestyle) + cur.execute(f"set datestyle to {datestyle}") if n != 1: # noinspection PyUnboundLocalVariable cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) cur.execute( - "insert into %s" - " values (%%s,%%s,%%s,%%s,%%s)" % table, params) - cur.execute("select * from %s" % table) + f"insert into {table}" + " values (%s,%s,%s,%s,%s)", params) + cur.execute(f"select * from {table}") d = cur.description for i in range(5): self.assertEqual(d[i].type_code, pgdb.DATETIME) @@ -632,7 +632,7 @@ def test_datetime(self): self.assertEqual(d[4].type_code, pgdb.TIMESTAMP) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") finally: con.close() @@ -642,23 +642,22 @@ def test_interval(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i interval)" % table) + cur.execute(f"create table {table} (i interval)") for n in range(3): if n == 0: # input as objects param = td if n == 1: # input as text - param = '%d days %d seconds %d microseconds ' % ( - td.days, td.seconds, td.microseconds) + param = (f'{td.days} days {td.seconds} seconds' + f' {td.microseconds} microseconds') elif n == 2: # input using type helpers param = pgdb.Interval( td.days, 0, 0, td.seconds, 
td.microseconds) for intervalstyle in ('sql_standard ', 'postgres', 'postgres_verbose', 'iso_8601'): - cur.execute("set intervalstyle to %s" % intervalstyle) + cur.execute(f"set intervalstyle to {intervalstyle}") # noinspection PyUnboundLocalVariable - cur.execute("insert into %s" - " values (%%s)" % table, [param]) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", [param]) + cur.execute(f"select * from {table}") tc = cur.description[0].type_code self.assertEqual(tc, pgdb.DATETIME) self.assertNotEqual(tc, pgdb.STRING) @@ -667,7 +666,7 @@ def test_interval(self): self.assertEqual(tc, pgdb.INTERVAL) row = cur.fetchone() self.assertEqual(row, (td,)) - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") finally: con.close() @@ -721,15 +720,15 @@ def test_insert_array(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s" - " (n smallint, i int[], t text[][])" % table) + cur.execute( + f"create table {table} (n smallint, i int[], t text[][])") params = [(n, v[0], v[1]) for n, v in enumerate(values)] # Note that we must explicit casts because we are inserting # empty arrays. Otherwise this is not necessary. cur.executemany( - "insert into %s values" - " (%%d,%%s::int[],%%s::text[][])" % table, params) - cur.execute("select i, t from %s order by n" % table) + f"insert into {table} values" + " (%d,%s::int[],%s::text[][])", params) + cur.execute(f"select i, t from {table} order by n") d = cur.description self.assertEqual(d[0].type_code, pgdb.ARRAY) self.assertNotEqual(d[0].type_code, pgdb.RECORD) @@ -755,7 +754,7 @@ def test_select_array(self): self.assertEqual(row, values) def test_unicode_list_and_tuple(self): - value = (u'Käse', u'Würstchen') + value = ('Käse', 'Würstchen') con = self._connect() try: cur = con.cursor() @@ -780,11 +779,11 @@ def test_insert_record(self): con = self._connect() cur = con.cursor() try: - cur.execute("create type %s as (name varchar, age int)" % record) - cur.execute("create table %s (n smallint, r %s)" % (table, record)) + cur.execute(f"create type {record} as (name varchar, age int)") + cur.execute(f"create table {table} (n smallint, r {record})") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select r from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select r from {table} order by n") type_code = cur.description[0].type_code self.assertEqual(type_code, record) self.assertEqual(type_code, pgdb.RECORD) @@ -796,8 +795,8 @@ def test_insert_record(self): self.assertEqual(con.type_cache[columns[1].type], 'int4') rows = cur.fetchall() finally: - cur.execute('drop table %s' % table) - cur.execute('drop type %s' % record) + cur.execute(f'drop table {table}') + cur.execute(f'drop type {record}') con.close() self.assertEqual(len(rows), len(values)) rows = [row[0] for row in rows] @@ -832,9 +831,9 @@ def test_custom_type(self): cur = con.cursor() params = enumerate(values) # params have __pg_repr__ method cur.execute( - 'create table "%s" (n smallint, b bit varying(7))' % table) - cur.executemany("insert into %s values (%%s,%%s)" % table, params) - cur.execute("select * from %s" % table) + f'create table "{table}" (n smallint, b bit varying(7))') + cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.execute(f"select * from {table}") rows = cur.fetchall() finally: con.close() @@ -845,7 +844,7 @@ def test_custom_type(self): params = (1, 
object()) # an object that cannot be handled self.assertRaises( pgdb.InterfaceError, cur.execute, - "insert into %s values (%%s,%%s)" % table, params) + f"insert into {table} values (%s,%s)", params) finally: con.close() @@ -887,7 +886,7 @@ def test_global_typecast(self): try: query = 'select 2::int2, 4::int4, 8::int8' self.assertIs(pgdb.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v # noqa: E731 + cast_int = lambda v: f'int({v})' # noqa: E731 pgdb.set_typecast('int4', cast_int) con = self._connect() try: @@ -974,23 +973,23 @@ def test_set_typecast_for_arrays(self): def test_unicode_with_utf8(self): table = self.table_prefix + 'booze' - s = u"He wes Leovenaðes sone — liðe him be Drihten" + s = "He wes Leovenaðes sone — liðe him be Drihten" con = self._connect() cur = con.cursor() try: - cur.execute("create table %s (t text)" % table) + cur.execute(f"create table {table} (t text)") try: cur.execute("set client_encoding=utf8") - cur.execute(u"select '%s'" % s) + cur.execute(f"select '{s}'") except Exception: self.skipTest("database does not support utf8") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (s,)) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", (s,)) + cur.execute(f"select * from {table}") output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (s, table)) + cur.execute(f"select t = '{s}' from {table}") output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (s,)) + cur.execute(f"select t = %s from {table}", (s,)) output4 = cur.fetchone()[0] finally: con.close() @@ -1005,23 +1004,23 @@ def test_unicode_with_utf8(self): def test_unicode_with_latin1(self): table = self.table_prefix + 'booze' - s = u"Ehrt den König seine Würde, ehret uns der Hände Fleiß." + s = "Ehrt den König seine Würde, ehret uns der Hände Fleiß." 
con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (t text)" % table) + cur.execute(f"create table {table} (t text)") try: cur.execute("set client_encoding=latin1") - cur.execute(u"select '%s'" % s) + cur.execute(f"select '{s}'") except Exception: self.skipTest("database does not support latin1") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (s,)) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", (s,)) + cur.execute(f"select * from {table}") output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (s, table)) + cur.execute(f"select t = '{s}' from {table}") output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (s,)) + cur.execute(f"select t = %s from {table}", (s,)) output4 = cur.fetchone()[0] finally: con.close() @@ -1040,11 +1039,10 @@ def test_bool(self): con = self._connect() try: cur = con.cursor() - cur.execute( - "create table %s (n smallint, booltest bool)" % table) + cur.execute(f"create table {table} (n smallint, booltest bool)") params = enumerate(values) - cur.executemany("insert into %s values (%%s,%%s)" % table, params) - cur.execute("select booltest from %s order by n" % table) + cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.execute(f"select booltest from {table} order by n") rows = cur.fetchall() self.assertEqual(cur.description[0].type_code, pgdb.BOOL) finally: @@ -1073,12 +1071,12 @@ def test_json(self): try: cur = con.cursor() try: - cur.execute("create table %s (jsontest json)" % table) + cur.execute(f"create table {table} (jsontest json)") except pgdb.ProgrammingError: self.skipTest('database does not support json') params = (pgdb.Json(inval),) - cur.execute("insert into %s values (%%s)" % table, params) - cur.execute("select jsontest from %s" % table) + cur.execute(f"insert into {table} values (%s)", params) + cur.execute(f"select jsontest from {table}") outval = cur.fetchone()[0] self.assertEqual(cur.description[0].type_code, pgdb.JSON) finally: @@ -1093,12 +1091,12 @@ def test_jsonb(self): try: cur = con.cursor() try: - cur.execute("create table %s (jsonbtest jsonb)" % table) + cur.execute(f"create table {table} (jsonbtest jsonb)") except pgdb.ProgrammingError: self.skipTest('database does not support jsonb') params = (pgdb.Json(inval),) - cur.execute("insert into %s values (%%s)" % table, params) - cur.execute("select jsonbtest from %s" % table) + cur.execute(f"insert into {table} values (%s)", params) + cur.execute(f"select jsonbtest from {table}") outval = cur.fetchone()[0] self.assertEqual(cur.description[0].type_code, pgdb.JSON) finally: @@ -1135,8 +1133,8 @@ def test_fetchall_with_various_sizes(self): for n in (1, 3, 5, 7, 10, 100, 1000): cur = con.cursor() try: - cur.execute('select n, n::text as s, n %% 2 = 1 as b' - ' from generate_series(1, %d) as s(n)' % n) + cur.execute('select n, n::text as s, n % 2 = 1 as b' + f' from generate_series(1, {n}) as s(n)') res = cur.fetchall() self.assertEqual(len(res), n, res) self.assertEqual(len(res[0]), 3) @@ -1212,13 +1210,13 @@ def test_transaction(self): con1.commit() con2 = self._connect() cur2 = con2.cursor() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) - cur1.execute("insert into %s values('Schlafly')" % table) - cur2.execute("select name from %s" % table) + cur1.execute(f"insert into {table} values('Schlafly')") + cur2.execute(f"select name from {table}") 
self.assertIsNone(cur2.fetchone()) con1.commit() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Schlafly',)) con2.close() con1.close() @@ -1231,10 +1229,10 @@ def test_autocommit(self): self.executeDDL1(cur1) con2 = self._connect() cur2 = con2.cursor() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) - cur1.execute("insert into %s values('Shmaltz Pastrami')" % table) - cur2.execute("select name from %s" % table) + cur1.execute(f"insert into {table} values('Shmaltz Pastrami')") + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Shmaltz Pastrami',)) con2.close() con1.close() @@ -1247,32 +1245,32 @@ def test_connection_as_contextmanager(self): try: cur = con.cursor() if autocommit: - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") else: cur.execute( - "create table %s (n smallint check(n!=4))" % table) + f"create table {table} (n smallint check(n!=4))") with con: - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") try: with con: - cur.execute("insert into %s values (3)" % table) - cur.execute("insert into %s values (4)" % table) + cur.execute(f"insert into {table} values (3)") + cur.execute(f"insert into {table} values (4)") except con.IntegrityError as error: self.assertTrue('check' in str(error).lower()) with con: - cur.execute("insert into %s values (5)" % table) - cur.execute("insert into %s values (6)" % table) + cur.execute(f"insert into {table} values (5)") + cur.execute(f"insert into {table} values (6)") try: with con: - cur.execute("insert into %s values (7)" % table) - cur.execute("insert into %s values (8)" % table) + cur.execute(f"insert into {table} values (7)") + cur.execute(f"insert into {table} values (8)") raise ValueError('transaction should rollback') except ValueError as error: self.assertEqual(str(error), 'transaction should rollback') with con: - cur.execute("insert into %s values (9)" % table) - cur.execute("select * from %s order by 1" % table) + cur.execute(f"insert into {table} values (9)") + cur.execute(f"select * from {table} order by 1") rows = cur.fetchall() rows = [row[0] for row in rows] finally: diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 769065ab..d461825c 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -100,7 +100,7 @@ class TestCopy(unittest.TestCase): @staticmethod def connect(): - host = '%s:%d' % (dbhost or '', dbport or -1) + host = f"{dbhost or ''}:{dbport or -1}" return pgdb.connect(database=dbname, host=host, user=dbuser, password=dbpasswd) @@ -163,11 +163,11 @@ def tearDown(self): @property def data_text(self): - return ''.join('%d\t%s\n' % row for row in self.data) + return ''.join('{}\t{}\n'.format(*row) for row in self.data) @property def data_csv(self): - return ''.join('%d,%s\n' % row for row in self.data) + return ''.join('{},{}\n'.format(*row) for row in self.data) def truncate_table(self): self.cursor.execute("truncate table copytest") @@ -259,7 +259,7 @@ def test_input_iterable_invalid(self): self.assertRaises(IOError, self.copy_from, [None]) def test_input_iterable_with_newlines(self): - self.copy_from('%s\n' % row for row in self.data_text.splitlines()) + self.copy_from(f'{row}\n' for row in self.data_text.splitlines()) 
self.check_table() def test_input_iterable_bytes(self): @@ -268,7 +268,7 @@ def test_input_iterable_bytes(self): self.check_table() def test_sep(self): - stream = ('%d-%s' % row for row in self.data) + stream = ('{}-{}'.format(*row) for row in self.data) self.copy_from(stream, sep='-') self.check_table() @@ -311,7 +311,7 @@ def test_csv(self): self.check_table() def test_csv_with_sep(self): - stream = ('%d;"%s"\n' % row for row in self.data) + stream = ('{};"{}"\n'.format(*row) for row in self.data) self.copy_from(stream, format='csv', sep=';') self.check_table() self.check_rowcount() @@ -326,7 +326,7 @@ def test_binary_with_sep(self): ValueError, self.copy_from, '', format='binary', sep='\t') def test_binary_with_unicode(self): - self.assertRaises(ValueError, self.copy_from, u'', format='binary') + self.assertRaises(ValueError, self.copy_from, '', format='binary') def test_query(self): self.assertRaises(ValueError, self.cursor.copy_from, '', "select null") @@ -441,10 +441,10 @@ def test_decode(self): def test_sep(self): ret = list(self.copy_to(sep='-')) - self.assertEqual(ret, ['%d-%s\n' % row for row in self.data]) + self.assertEqual(ret, ['{}-{}\n'.format(*row) for row in self.data]) def test_null(self): - data = ['%d\t%s\n' % row for row in self.data] + data = ['{}\t{}\n'.format(*row) for row in self.data] self.cursor.execute('insert into copytest values(4, null)') try: ret = list(self.copy_to()) @@ -457,8 +457,8 @@ def test_null(self): self.cursor.execute('delete from copytest where id=4') def test_columns(self): - data_id = ''.join('%d\n' % row[0] for row in self.data) - data_name = ''.join('%s\n' % row[1] for row in self.data) + data_id = ''.join(f'{row[0]}\n' for row in self.data) + data_name = ''.join(f'{row[1]}\n' for row in self.data) ret = ''.join(self.copy_to(columns='id')) self.assertEqual(ret, data_id) ret = ''.join(self.copy_to(columns=['id'])) @@ -513,7 +513,7 @@ def test_query(self): rows = list(ret) self.assertEqual(len(rows), 1) self.assertIsInstance(rows[0], str) - self.assertEqual(rows[0], '%s!\n' % self.data[1][1]) + self.assertEqual(rows[0], f'{self.data[1][1]}!\n') self.check_rowcount(1) def test_file(self): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 0193165a..a497914b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -108,7 +108,7 @@ class TestDbApi20Tutorial(unittest.TestCase): def setUp(self): """Setup test tables or empty them if they already exist.""" - host = '%s:%d' % (dbhost or '', dbport or -1) + host = f"{dbhost or ''}:{dbport or -1}" con = connect(database=dbname, host=host, user=dbuser, password=dbpasswd) cur = con.cursor() From b87e87bbde0d3f8e959d86820db4e15c9bd31b57 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 13:32:33 +0200 Subject: [PATCH 134/194] Add next row extension to DBAPI 20 conformance test --- tests/dbapi20.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 798bbc49..32045fa4 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -461,6 +461,48 @@ def test_fetchone(self): finally: con.close() + def test_next(self): + """Extension for getting the next row""" + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur, 'next'): + return + + # cursor.next should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error, cur.next) + + # cursor.next should raise an Error if called after + # executing a query that cannot return 
rows
+            self.executeDDL1(cur)
+            self.assertRaises(self.driver.Error, cur.next)
+
+            # cursor.next should raise StopIteration
+            # if a query retrieves no rows
+            cur.execute(f'select name from {self.table_prefix}booze')
+            self.assertRaises(StopIteration, cur.next)
+            self.assertIn(cur.rowcount, (-1, 0))
+
+            # cursor.next should raise an Error if called after
+            # executing a query that cannot return rows
+            cur.execute(f"{self.insert} into {self.table_prefix}booze"
+                        " values ('Victoria Bitter')")
+            self.assertRaises(self.driver.Error, cur.next)
+
+            cur.execute(f'select name from {self.table_prefix}booze')
+            r = cur.next()
+            self.assertEqual(
+                len(r), 1,
+                'cursor.next should have retrieved a single row')
+            self.assertEqual(
+                r[0], 'Victoria Bitter',
+                'cursor.next retrieved incorrect data')
+            # cursor.next should raise StopIteration if no more rows available
+            self.assertRaises(StopIteration, cur.next)
+            self.assertIn(cur.rowcount, (-1, 1))
+        finally:
+            con.close()
+
     samples = [
         'Carlton Cold',
         'Carlton Draft',

From 324b8fc9cae976e2b8e42e69aace59ff36fb2881 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Fri, 1 Sep 2023 13:53:22 +0200
Subject: [PATCH 135/194] Some more string formatting modernization

---
 docs/conf.py                        | 2 +-
 docs/contents/pg/adaptation.rst     | 4 ++--
 docs/contents/pg/connection.rst     | 2 +-
 docs/contents/pg/db_wrapper.rst     | 2 +-
 docs/contents/pg/module.rst         | 8 ++++----
 docs/contents/pgdb/adaptation.rst   | 4 ++--
 docs/contents/postgres/advanced.rst | 6 +++---
 setup.py                            | 6 +++---
 8 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 933c4e38..f5789d29 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -68,7 +68,7 @@
 html_theme = 'alabaster'
 html_static_path = ['_static']
 
-html_title = 'PyGreSQL %s' % version
+html_title = f'PyGreSQL {version}'
 html_logo = '_static/pygresql.png'
 html_favicon = '_static/favicon.ico'
 
diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst
index c5d0a795..de82cbfa 100644
--- a/docs/contents/pg/adaptation.rst
+++ b/docs/contents/pg/adaptation.rst
@@ -231,7 +231,7 @@ our values::
 ...         self.price = price
 ...
 ...     def __str__(self):
-...         return '%s (from %s, at $%s)' % (
+...         return '{} (from {}, at ${})'.format(
 ...             self.name, self.supplier_id, self.price)
 
 But when we try to insert an instance of this class in the same way, we
@@ -246,7 +246,7 @@ PostgreSQL by adding a "magic" method with the name ``__pg_str__``, like so::
 ...
 ...     ...
 ...     def __str__(self):
-...         return '%s (from %s, at $%s)' % (
+...         return '{} (from {}, at ${})'.format(
 ...             self.name, self.supplier_id, self.price)
 ...
 ...     def __pg_str__(self, typ):
diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst
index 237e25a8..1adf29d1 100644
--- a/docs/contents/pg/connection.rst
+++ b/docs/contents/pg/connection.rst
@@ -157,7 +157,7 @@ Examples::
     s = con1.query("begin; set transaction isolation level repeatable read;"
                    "select pg_export_snapshot();").single()
     con2.query("begin; set transaction isolation level repeatable read;"
-               "set transaction snapshot '%s'" % (s,))
+               f"set transaction snapshot '{s}'")
     q1 = con1.send_query("select a,b,c from x where d=e")
     q2 = con2.send_query("select e,f from y where g")
     r1 = q1.getresult()
diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst
index 5d587f97..68d33c65 100644
--- a/docs/contents/pg/db_wrapper.rst
+++ b/docs/contents/pg/db_wrapper.rst
@@ -16,7 +16,7 @@ The preferred way to use this module is as follows::
     for r in db.query(  # just for example
             "SELECT foo, bar FROM foo_bar_table WHERE foo !~ bar"
             ).dictresult():
-        print('%(foo)s %(bar)s' % r)
+        print('{foo} {bar}'.format(**r))
 
 This class can be subclassed as in this example::
 
diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst
index 203ada03..2dc26d5f 100644
--- a/docs/contents/pg/module.rst
+++ b/docs/contents/pg/module.rst
@@ -289,8 +289,8 @@ which takes connection properties into account.
 
 Example::
 
     name = input("Name? ")
-    phone = con.query("select phone from employees where name='%s'"
-                      % escape_string(name)).getresult()
+    phone = con.query("select phone from employees"
+                      f" where name='{escape_string(name)}'").singlescalar()
 
 escape_bytea -- escape binary data for use within SQL
 -----------------------------------------------------
@@ -315,8 +315,8 @@ which takes connection properties into account.
 
 Example::
 
     picture = open('garfield.gif', 'rb').read()
-    con.query("update pictures set img='%s' where name='Garfield'"
-              % escape_bytea(picture))
+    con.query(f"update pictures set img='{escape_bytea(picture)}'"
+              " where name='Garfield'")
 
 unescape_bytea -- unescape data that has been retrieved as text
 ---------------------------------------------------------------
diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst
index ebb36e5b..ac649a21 100644
--- a/docs/contents/pgdb/adaptation.rst
+++ b/docs/contents/pgdb/adaptation.rst
@@ -209,7 +209,7 @@ to hold our values, like this one::
 ...         self.price = price
 ...
 ...     def __str__(self):
-...         return '%s (from %s, at $%s)' % (
+...         return '{} (from {}, at ${})'.format(
 ...             self.name, self.supplier_id, self.price)
 
 But when we try to insert an instance of this class in the same way, we
@@ -231,7 +231,7 @@ with the name ``__pg_repr__``, like this::
 ...
 ...     ...
 ...     def __str__(self):
-...         return '%s (from %s, at $%s)' % (
+...         return '{} (from {}, at ${})'.format(
 ...             self.name, self.supplier_id, self.price)
 ...
 ...     def __pg_repr__(self):
diff --git a/docs/contents/postgres/advanced.rst b/docs/contents/postgres/advanced.rst
index e3e2ab10..d7627312 100644
--- a/docs/contents/postgres/advanced.rst
+++ b/docs/contents/postgres/advanced.rst
@@ -27,7 +27,7 @@ all data fields from cities)::
 ...         "'Las Vegas', 2.583E+5, 2174",
 ...         "'Mariposa', 1200, 1953"]),
 ...     ('capitals', [
-...         "'Sacramento',3.694E+5,30,'CA'",
+...         "'Sacramento', 3.694E+5,30, 'CA'",
 ...         "'Madison', 1.913E+5, 845, 'WI'"])]
 
 Now, let's populate the tables::
@@ -37,11 +37,11 @@ Now, let's populate the tables::
 ...     "'Las Vegas', 2.583E+5, 2174"
 ...     "'Mariposa', 1200, 1953"],
 ...     'capitals', [
-...     "'Sacramento',3.694E+5,30,'CA'",
+...     "'Sacramento', 3.694E+5,30, 'CA'",
 ...     "'Madison', 1.913E+5, 845, 'WI'"]]
 >>> for table, rows in data:
 ...     for row in rows:
-...         query("INSERT INTO %s VALUES (%s)" % (table, row))
+...         query(f"INSERT INTO {table} VALUES ({row})")
 >>> print(query("SELECT * FROM cities"))
 name         |population|altitude
 -------------+----------+--------
diff --git a/setup.py b/setup.py
index 456e3b5e..08d43dae 100755
--- a/setup.py
+++ b/setup.py
@@ -56,7 +56,7 @@
 
 if not (3, 7) <= sys.version_info[:2] < (4, 0):
     raise Exception(
-        "Sorry, PyGreSQL %s does not support this Python version" % version)
+        f"Sorry, PyGreSQL {version} does not support this Python version")
 
 # For historical reasons, PyGreSQL does not install itself as a single
 # "pygresql" package, but as two top-level modules "pg", providing the
@@ -69,12 +69,12 @@
 
 def pg_config(s):
     """Retrieve information about installed version of PostgreSQL."""
-    f = os.popen('pg_config --%s' % s)
+    f = os.popen(f'pg_config --{s}')
     d = f.readline().strip()
     if f.close() is not None:
         raise Exception("pg_config tool is not available.")
     if not d:
-        raise Exception("Could not get %s information." % s)
+        raise Exception(f"Could not get {s} information.")
     return d
 

From e3398d5ec3b919de4e81988615302421893124cc Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Fri, 1 Sep 2023 14:03:26 +0200
Subject: [PATCH 136/194] Add more context to errors

---
 docs/contents/postgres/func.rst |  2 +-
 pg.py                           | 18 +++++++++---------
 pgdb.py                         | 12 ++++++------
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/docs/contents/postgres/func.rst b/docs/contents/postgres/func.rst
index 9d0f5967..3bfcfd98 100644
--- a/docs/contents/postgres/func.rst
+++ b/docs/contents/postgres/func.rst
@@ -62,7 +62,7 @@ Before we create more sophisticated functions, let's populate an EMP table::
 ...     "'Bill', 4200, 36, 'shoe'",
 ...     "'Ginger', 4800, 30, 'candy'"]
 >>> for emp in emps:
-...     query("INSERT INTO EMP VALUES (%s)" % emp)
+...     query(f"INSERT INTO EMP VALUES ({emp})")
 
 Every INSERT statement will return a '1' indicating that it has inserted
 one row into the EMP table.
diff --git a/pg.py b/pg.py
index 50f22425..923d2743 100644
--- a/pg.py
+++ b/pg.py
@@ -950,7 +950,7 @@ def __missing__(self, typ):
         but returns None when no special cast function exists.
         
""" if not isinstance(typ, str): - raise TypeError('Invalid type: {typ}') + raise TypeError(f'Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -2257,8 +2257,8 @@ def update(self, table, row=None, **kw): else: # try using the primary key try: keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: # the table has no primary key + raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') @@ -2359,8 +2359,8 @@ def upsert(self, table, row=None, **kw): names, values = ', '.join(names), ', '.join(values) try: keyname = self.pkey(table, True) - except KeyError: - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: + raise _prg_error(f'Table {table} has no primary key') from e target = ', '.join(col(k) for k in keyname) update = [] keyname = set(keyname) @@ -2444,8 +2444,8 @@ def delete(self, table, row=None, **kw): else: # try using the primary key try: keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: # the table has no primary key + raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') @@ -2612,8 +2612,8 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, if not keyname: try: keyname = self.pkey(table, True) - except (KeyError, ProgrammingError): - raise _prg_error(f'Table {table} has no primary key') + except (KeyError, ProgrammingError) as e: + raise _prg_error(f'Table {table} has no primary key') from e if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): diff --git a/pgdb.py b/pgdb.py index 5e218b42..44b6a83e 100644 --- a/pgdb.py +++ b/pgdb.py @@ -865,9 +865,9 @@ def _quote(self, value): return f'({v})' try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() - except AttributeError: + except AttributeError as e: raise InterfaceError( - f'Do not know how to adapt type {type(value)}') + f'Do not know how to adapt type {type(value)}') from e if isinstance(value, (tuple, list)): value = self._quote(value) return value @@ -965,8 +965,8 @@ def executemany(self, operation, seq_of_parameters): self._src.execute(sql) except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") + except Exception as e: + raise _op_error("Can't start transaction") from e else: self._dbcnx._tnx = True for parameters in seq_of_parameters: @@ -983,7 +983,7 @@ def executemany(self, operation, seq_of_parameters): # noinspection PyTypeChecker raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) except Exception as err: - raise _op_error(f"Internal error in '{sql}': {err}") + raise _op_error(f"Internal error in '{sql}': {err}") from err # then initialize result raw count and description if self._src.resulttype == RESULT_DQL: self._description = True # fetch on demand @@ -1027,7 +1027,7 @@ def fetchmany(self, size=None, keep=False): except DatabaseError: raise except Error as err: - raise _db_error(str(err)) + raise _db_error(str(err)) from err row_factory = self.row_factory coltypes = self.coltypes if len(result) > 5: 
From 816ec354723da1571a682e3215b06e4c6a9c7b1d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 20:01:48 +0200 Subject: [PATCH 137/194] Replace flake8 with ruff and use pyproject.toml Also fix some minor issues that were detected by ruff. Remove announcements from static docs and inline about page. --- .bumpversion.cfg | 16 ---- .devcontainer/devcontainer.json | 1 + .devcontainer/provision.sh | 10 ++- .flake8 | 4 - .github/workflows/lint.yml | 2 +- docs/about.rst | 42 +++++++++- docs/about.txt | 41 ---------- docs/announce.rst | 26 ------ docs/conf.py | 2 +- docs/contents/changelog.rst | 6 ++ docs/download/index.rst | 6 +- docs/index.rst | 1 - pg.py | 131 +++++++++++++++++++++---------- pgdb.py | 63 +++++++++------ pyproject.toml | 75 ++++++++++++++++++ setup.py | 89 +++------------------ tests/config.py | 4 +- tests/dbapi20.py | 5 +- tests/test_classic.py | 24 +++--- tests/test_classic_connection.py | 31 ++++---- tests/test_classic_dbwrapper.py | 43 +++++----- tests/test_classic_functions.py | 10 +-- tests/test_classic_largeobj.py | 16 ++-- tests/test_dbapi20.py | 2 +- tests/test_dbapi20_copy.py | 29 ++++--- tests/test_tutorial.py | 4 +- tox.ini | 17 +++- 27 files changed, 372 insertions(+), 328 deletions(-) delete mode 100644 .flake8 delete mode 100644 docs/about.txt delete mode 100644 docs/announce.rst create mode 100644 pyproject.toml diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 89aec55e..769d02cf 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -12,22 +12,6 @@ serialize = search = version = '{current_version}' replace = version = '{new_version}' -[bumpversion:file (head):setup.py] -search = PyGreSQL version {current_version} -replace = PyGreSQL version {new_version} - [bumpversion:file:docs/conf.py] search = version = release = '{current_version}' replace = version = release = '{new_version}' - -[bumpversion:file:docs/about.txt] -search = PyGreSQL {current_version} -replace = PyGreSQL {new_version} - -[bumpversion:file:docs/announce.rst] -search = PyGreSQL version {current_version} -replace = PyGreSQL version {new_version} - -[bumpversion:file (text):docs/announce.rst] -search = Release {current_version} of PyGreSQL -replace = Release {new_version} of PyGreSQL diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index c1374910..b9fbaaeb 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -45,6 +45,7 @@ "njpwerner.autodocstring", "redhat.vscode-yaml", "eamodio.gitlens", + "charliermarsh.ruff", "streetsidesoftware.code-spell-checker", "lextudio.restructuredtext" ] diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index b47abb8c..a42337b8 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -27,9 +27,15 @@ sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils -# install testing tool +# install build and testing tool -sudo apt-get install -y tox +python3.7 -m pip install build +python3.8 -m pip install build +python3.9 -m pip install build +python3.10 -m pip install build +python3.11 -m pip install build + +sudo apt-get install -y tox python3-poetry # install PostgreSQL client tools diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 3f6e0a3c..00000000 --- a/.flake8 +++ /dev/null @@ -1,4 +0,0 @@ -[flake8] -ignore = F403,F405,W503 -exclude = .git,.tox,.venv,build,dist,docs 
-max-line-length = 79 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 54ae2fd3..40f5299e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,5 +22,5 @@ jobs: with: python-version: 3.11 - name: Run quality checks - run: tox -e flake8,docs + run: tox -e ruff,docs timeout-minutes: 5 diff --git a/docs/about.rst b/docs/about.rst index 3e61d030..8235e5cc 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -1,4 +1,44 @@ About PyGreSQL ============== -.. include:: about.txt \ No newline at end of file +**PyGreSQL** is an *open-source* `Python `_ module +that interfaces to a `PostgreSQL `_ database. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. + + | This software is copyright © 1995, Pascal Andre. + | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. + | Further modifications are copyright © 2009-2023 by the PyGreSQL team. + | For licensing details, see the full :doc:`copyright`. + +**PostgreSQL** is a highly scalable, SQL compliant, open source +object-relational database management system. With more than 20 years +of development history, it is quickly becoming the de facto database +for enterprise level open source solutions. +Best of all, PostgreSQL's source code is available under the most liberal +open source license: the BSD license. + +**Python** Python is an interpreted, interactive, object-oriented +programming language. It is often compared to Tcl, Perl, Scheme or Java. +Python combines remarkable power with very clear syntax. It has modules, +classes, exceptions, very high level dynamic data types, and dynamic typing. +There are interfaces to many system calls and libraries, as well as to +various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules +are easily written in C or C++. Python is also usable as an extension +language for applications that need a programmable interface. +The Python implementation is copyrighted but freely usable and distributable, +even for commercial use. + +**PyGreSQL** is a Python module that interfaces to a PostgreSQL database. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. + +PyGreSQL is developed and tested on a NetBSD system, but it also runs on +most other platforms where PostgreSQL and Python is running. It is based +on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). +D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with +version 2.0 and serves as the "BDFL" of PyGreSQL. + +The current version PyGreSQL |version| needs PostgreSQL 10 to 15, and Python +3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/about.txt b/docs/about.txt deleted file mode 100644 index 04f615e1..00000000 --- a/docs/about.txt +++ /dev/null @@ -1,41 +0,0 @@ -**PyGreSQL** is an *open-source* `Python `_ module -that interfaces to a `PostgreSQL `_ database. -It wraps the lower level C API library libpq to allow easy use of the -powerful PostgreSQL features from Python. - - | This software is copyright © 1995, Pascal Andre. - | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2023 by the PyGreSQL team. - | For licensing details, see the full :doc:`copyright`. 
- -**PostgreSQL** is a highly scalable, SQL compliant, open source -object-relational database management system. With more than 20 years -of development history, it is quickly becoming the de facto database -for enterprise level open source solutions. -Best of all, PostgreSQL's source code is available under the most liberal -open source license: the BSD license. - -**Python** Python is an interpreted, interactive, object-oriented -programming language. It is often compared to Tcl, Perl, Scheme or Java. -Python combines remarkable power with very clear syntax. It has modules, -classes, exceptions, very high level dynamic data types, and dynamic typing. -There are interfaces to many system calls and libraries, as well as to -various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules -are easily written in C or C++. Python is also usable as an extension -language for applications that need a programmable interface. -The Python implementation is copyrighted but freely usable and distributable, -even for commercial use. - -**PyGreSQL** is a Python module that interfaces to a PostgreSQL database. -It wraps the lower level C API library libpq to allow easy use of the -powerful PostgreSQL features from Python. - -PyGreSQL is developed and tested on a NetBSD system, but it also runs on -most other platforms where PostgreSQL and Python is running. It is based -on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with -version 2.0 and serves as the "BDFL" of PyGreSQL. - -The current version PyGreSQL 6.0 needs PostgreSQL 10 to 15, and Python -3.7 to 3.11. If you need to support older PostgreSQL or Python versions, -you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst deleted file mode 100644 index d0a5f19c..00000000 --- a/docs/announce.rst +++ /dev/null @@ -1,26 +0,0 @@ -====================== -PyGreSQL Announcements -====================== - -------------------------------- -Release of PyGreSQL version 6.0 -------------------------------- - -Release 6.0 of PyGreSQL. - -It is available at: https://pypi.org/project/PyGreSQL/. - -If you are running NetBSD, look in the packages directory under databases. -There is also a package in the FreeBSD ports collection. - -Please refer to `changelog.txt `_ -for things that have changed in this version. - -This version has been built and unit tested on: - - Ubuntu - - Windows 7 and 10 with Visual Studio - - PostgreSQL 10 to 15 (32 and 64bit) - - Python 3.7 to 3.11 (32 and 64bit) - -| D'Arcy J.M. Cain -| darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index f5789d29..0f95ab1b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,7 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '5.2.5' +version = release = '6.0' language = 'en' diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index bc8322f4..e2b68425 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.0 (to be released) +---------------------------- +- Removed support for Python versions older than 3.7 (released June 2017) + and PostgreSQL older than version 10 (released October 2017). +- Modernized code and tools for development, testing, linting and building. + Version 5.2.5 (2023-08-28) -------------------------- - This version officially supports the new Python 3.11 and PostgreSQL 15. 
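
With the flake8 configuration replaced by the ``[tool.ruff]`` table that
this patch adds to ``pyproject.toml``, the quality checks can be run either
through tox, as the updated GitHub workflow above does, or with ruff
directly. A usage sketch, assuming ruff is installed in the active
environment::

    # mirror the CI quality checks
    tox -e ruff,docs

    # or run the linter directly; ruff reads its configuration
    # from the [tool.ruff] table in pyproject.toml
    ruff check .
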
diff --git a/docs/download/index.rst b/docs/download/index.rst index c4735826..88bf77b0 100644 --- a/docs/download/index.rst +++ b/docs/download/index.rst @@ -3,10 +3,8 @@ Download information .. include:: download.rst -News, Changes and Future Development ------------------------------------- - -See the :doc:`../announce` for current news. +Changes and Future Development +------------------------------ For a list of all changes in the current version |version| and in past versions, have a look at the :doc:`../contents/changelog`. diff --git a/docs/index.rst b/docs/index.rst index c40103a8..88292059 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,6 @@ Welcome to PyGreSQL about copyright - announce download/index contents/index community/index diff --git a/pg.py b/pg.py index 923d2743..70de429e 100644 --- a/pg.py +++ b/pg.py @@ -22,7 +22,7 @@ try: from _pg import version -except ImportError as e: +except ImportError as e: # noqa: F841 import os libpq = 'libpq.' if os.name == 'nt': @@ -55,29 +55,69 @@ # import objects from extension module from _pg import ( - Error, Warning, - DataError, DatabaseError, - IntegrityError, InterfaceError, InternalError, - InvalidResultError, MultipleResultsError, - NoResultError, NotSupportedError, - OperationalError, ProgrammingError, - INV_READ, INV_WRITE, - POLLING_OK, POLLING_FAILED, POLLING_READING, POLLING_WRITING, - SEEK_CUR, SEEK_END, SEEK_SET, - TRANS_ACTIVE, TRANS_IDLE, TRANS_INERROR, - TRANS_INTRANS, TRANS_UNKNOWN, - cast_array, cast_hstore, cast_record, - connect, escape_bytea, escape_string, unescape_bytea, - get_array, get_bool, get_bytea_escaped, - get_datestyle, get_decimal, get_decimal_point, - get_defbase, get_defhost, get_defopt, get_defport, get_defuser, - get_jsondecode, get_pqlib_version, - set_array, set_bool, set_bytea_escaped, - set_datestyle, set_decimal, set_decimal_point, - set_defbase, set_defhost, set_defopt, - set_defpasswd, set_defport, set_defuser, - set_jsondecode, set_query_helpers, - version) + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) __version__ = version @@ -112,19 +152,18 @@ import select import warnings import weakref - -from datetime import date, time, datetime, timedelta +from collections import OrderedDict, namedtuple +from datetime import date, datetime, time, timedelta from decimal import Decimal -from math import isnan, isinf -from collections import namedtuple, OrderedDict +from functools import lru_cache, partial from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan from operator 
import itemgetter -from functools import lru_cache, partial from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode -from uuid import UUID from typing import Dict, List, Union # noqa: F401 - +from uuid import UUID # Auxiliary classes and functions that are independent of a DB connection: @@ -174,6 +213,7 @@ def _quote(cls, s): return s def __str__(self): + """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -182,10 +222,12 @@ class Json: """Wrapper class for marking Json values.""" def __init__(self, obj, encode=None): + """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode def __str__(self): + """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): return obj @@ -313,6 +355,7 @@ class Adapter: _re_array_escape = _re_record_escape = regex(r'(["\\])') def __init__(self, db): + """Initialize the adapter object with the given connection.""" self.db = weakref.proxy(db) @classmethod @@ -1124,7 +1167,7 @@ class DbTypes(dict): def __init__(self, db): """Initialize type cache for connection.""" - super(DbTypes, self).__init__() + super().__init__() self._db = weakref.proxy(db) self._regtypes = False self._typecasts = Typecasts() @@ -1315,7 +1358,7 @@ def _prg_error(msg): # The notification handler -class NotificationHandler(object): +class NotificationHandler: """A PostgreSQL client-side asynchronous notification handler.""" def __init__(self, db, event, callback=None, @@ -1348,6 +1391,7 @@ def __init__(self, db, event, callback=None, self.timeout = timeout def __del__(self): + """Delete the notification handler.""" self.unlisten() def close(self): @@ -1440,7 +1484,10 @@ def __call__(self): def pgnotify(*args, **kw): - """Same as NotificationHandler, under the traditional name.""" + """Create a notification handler. + + Same as NotificationHandler, under the traditional name. + """ warnings.warn("pgnotify is deprecated, use NotificationHandler instead", DeprecationWarning, stacklevel=2) return NotificationHandler(*args, **kw) @@ -1454,7 +1501,7 @@ class DB: db = None # invalid fallback for underlying connection def __init__(self, *args, **kw): - """Create a new connection + """Create a new connection. You can pass either the connection parameters or an existing _pg or pgdb connection. 
This allows you to use the methods
@@ -1519,6 +1566,7 @@ def __init__(self, *args, **kw):
         self.debug = None
 
     def __getattr__(self, name):
+        """Get the specified attribute of the connection."""
         # All undefined members are same as in underlying connection:
         if self.db:
             return getattr(self.db, name)
@@ -1526,6 +1574,7 @@ def __getattr__(self, name):
         raise _int_error('Connection is not valid')
 
     def __dir__(self):
+        """List all attributes of the connection."""
         # Custom dir function including the attributes of the connection:
         attrs = set(self.__class__.__dict__)
         attrs.update(self.__dict__)
@@ -1547,6 +1596,7 @@ def __exit__(self, et, ev, tb):
             self.rollback()
 
     def __del__(self):
+        """Delete the connection."""
         try:
             db = self.db
         except AttributeError:
@@ -1565,7 +1615,7 @@ def __del__(self):
     # Auxiliary methods
 
     def _do_debug(self, *args):
-        """Print a debug message"""
+        """Print a debug message."""
         if self.debug:
             s = '\n'.join(str(arg) for arg in args)
             if isinstance(self.debug, str):
@@ -1918,7 +1968,7 @@ def describe_prepared(self, name=None):
         return self.db.describe_prepared(name)
 
     def delete_prepared(self, name=None):
-        """Delete a prepared SQL statement
+        """Delete a prepared SQL statement.
 
         This deallocates a previously prepared SQL statement with the given
         name, or deallocates all prepared statements if you do not specify a
@@ -2275,8 +2325,7 @@ def update(self, table, row=None, **kw):
         keyname = set(keyname)
         for n in attnames:
             if n in row and n not in keyname and n not in generated:
-                values.append('{} = {}'.format(
-                    col(n), adapt(row[n], attnames[n])))
+                values.append(f'{col(n)} = {adapt(row[n], attnames[n])}')
         if not values:
             return row
         values = ', '.join(values)
@@ -2294,7 +2343,7 @@ def update(self, table, row=None, **kw):
         return row
 
     def upsert(self, table, row=None, **kw):
-        """Insert a row into a database table with conflict resolution
+        """Insert a row into a database table with conflict resolution.
 
         This method inserts a row into a table, but instead of raising a
         ProgrammingError exception in case a row with the same primary key
diff --git a/pgdb.py b/pgdb.py
index 44b6a83e..f61522bb 100644
--- a/pgdb.py
+++ b/pgdb.py
@@ -66,7 +66,7 @@
 
 try:
     from _pg import version
-except ImportError as e:
+except ImportError as e:  # noqa: F841
     import os
     libpq = 'libpq.'
     
if os.name == 'nt': @@ -99,14 +99,24 @@ # import objects from extension module from _pg import ( - Error, Warning, - DataError, DatabaseError, - IntegrityError, InterfaceError, InternalError, - NotSupportedError, OperationalError, ProgrammingError, - cast_array, cast_hstore, cast_record, RESULT_DQL, - connect, unescape_bytea, - version) + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + unescape_bytea, + version, +) __version__ = version @@ -127,17 +137,18 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from datetime import date, time, datetime, timedelta -from time import localtime -from decimal import Decimal as StdDecimal -from uuid import UUID as Uuid -from math import isnan, isinf from collections import namedtuple from collections.abc import Iterable -from inspect import signature +from datetime import date, datetime, time, timedelta +from decimal import Decimal as StdDecimal from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode +from time import localtime +from uuid import UUID as Uuid Decimal = StdDecimal @@ -623,7 +634,7 @@ class TypeCache(dict): def __init__(self, cnx): """Initialize type cache for connection.""" - super(TypeCache, self).__init__() + super().__init__() self._escape_string = cnx.escape_string self._src = cnx.source() self._typecasts = LocalTypecasts() @@ -726,7 +737,7 @@ class _quotedict(dict): def __getitem__(self, key): # noinspection PyUnresolvedReferences - return self.quote(super(_quotedict, self).__getitem__(key)) + return self.quote(super().__getitem__(key)) # *** Error Messages *** @@ -777,7 +788,7 @@ def set_row_factory_size(maxsize): # *** Cursor Object *** -class Cursor(object): +class Cursor: """Cursor object.""" def __init__(self, dbcnx): @@ -1369,7 +1380,7 @@ def build_row_factory(self): # *** Connection Objects *** -class Connection(object): +class Connection: """Connection object.""" # expose the exceptions as attributes on the connection object @@ -1576,25 +1587,28 @@ class Type(frozenset): """ def __new__(cls, values): + """Create new type object.""" if isinstance(values, str): values = values.split() - return super(Type, cls).__new__(cls, values) + return super().__new__(cls, values) def __eq__(self, other): + """Check whether types are considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self else: - return super(Type, self).__eq__(other) + return super().__eq__(other) def __ne__(self, other): + """Check whether types are not considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self else: - return super(Type, self).__ne__(other) + return super().__ne__(other) class ArrayType: @@ -1741,6 +1755,7 @@ def _quote(cls, s): return s def __str__(self): + """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -1749,10 +1764,12 @@ class Json: """Construct a wrapper for holding an object serializable to JSON.""" def __init__(self, obj, encode=None): + """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode def 
__str__(self): + """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): return obj @@ -1763,9 +1780,11 @@ class Literal: """Construct a wrapper for holding a literal SQL string.""" def __init__(self, sql): + """Initialize literal SQL string.""" self.sql = sql def __str__(self): + """Return a printable representation of the SQL string.""" return self.sql __pg_repr__ = __str__ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b1a184cc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,75 @@ +[project] +name = "PyGreSQL" +version = "6.0" +requires-python = ">=3.7" +authors = [ + {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, + {name = "Christoph Zwerschke", email = "cito@online.de"}, +] +description = "Python PostgreSQL interfaces" +readme = "README.rst" +keywords = ["pygresql", "postgresql", "database", "api", "dbapi"] +classifiers = [ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: PostgreSQL License", + "Operating System :: OS Independent", + "Programming Language :: C", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.license] +file = "LICENSE.txt" + +[project.urls] +homepage = "https://pygresql.github.io/" +documentation = "https://pygresql.github.io/contents/" +source = "https://github.com/PyGreSQL/PyGreSQL" +issues = "https://github.com/PyGreSQL/PyGreSQL/issues/" +changelog = "https://pygresql.github.io/contents/changelog.html" +download = "https://pygresql.github.io/download/" +"mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" + +[tool.ruff] +line-length = 79 +select = [ + "E", # pycodestyle + "F", # pyflakes + "UP", # pyupgrade + "D", # pydocstyle +] +exclude = [ + "__pycache__", + "__pypackages__", + ".git", + ".tox", + ".venv", + ".devcontainer", + ".vscode", + "docs", + "build", + "dist", + "local", + "venv", +] + +[tool.ruff.per-file-ignores] +"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107"] + +[tool.setuptools] +py-modules = ["pg", "pgdb"] +license-files = ["LICENSE.txt"] + +[build-system] +requires = ["setuptools>=68", "wheel>=0.41"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 08d43dae..29a84bf8 100755 --- a/setup.py +++ b/setup.py @@ -1,41 +1,11 @@ #!/usr/bin/python -# -# PyGreSQL - a Python interface for the PostgreSQL database. -# -# Copyright (c) 2023 by the PyGreSQL Development Team -# -# Please see the LICENSE.TXT file for specific restrictions. - -"""Setup script for PyGreSQL version 6.0 - -PyGreSQL is an open-source Python module that interfaces to a -PostgreSQL database. It wraps the lower level C API library libpq -to allow easy use of the powerful PostgreSQL features from Python. - -Authors and history: -* PyGreSQL written 1997 by D'Arcy J.M. 
Cain -* based on code written 1995 by Pascal Andre -* setup script created 2000 by Mark Alexander -* improved 2000 by Jeremy Hylton -* improved 2001 by Gerhard Haering -* improved 2006 to 2018 by Christoph Zwerschke - -Prerequisites to be installed: -* Python including devel package (header files and distutils) -* PostgreSQL libs and devel packages (header file of the libpq client) -* PostgreSQL pg_config tool (usually included in the devel package) - (the Windows installer has it as part of the database server feature) - -PyGreSQL currently supports Python versions 3.7 to 3.11, -and PostgreSQL versions 10 to 15. - -Use as follows: -python setup.py build_ext # to build the module -python setup.py install # to install it - -See docs.python.org/doc/install/ for more information on -using distutils to install Python programs. +"""Driver script for building PyGreSQL using setuptools. + +You can build the PyGreSQL distribution like this: + + pip install build + python -m build -C strict -C memory-size """ import os @@ -43,15 +13,12 @@ import re import sys import warnings -try: - from setuptools import setup -except ImportError: - from distutils.core import setup -from distutils.extension import Extension -from distutils.command.build_ext import build_ext from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + version = '6.0' if not (3, 7) <= sys.version_info[:2] < (4, 0): @@ -63,10 +30,8 @@ # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". -py_modules = ['pg', 'pgdb'] c_sources = ['pgmodule.c'] - def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" f = os.popen(f'pg_config --{s}') @@ -118,6 +83,7 @@ def get_compiler(self): return self.compiler or get_default_compiler() def initialize_options(self): + """Initialize the supported options with default values.""" build_ext.initialize_options(self) self.strict = False self.memory_size = None @@ -157,45 +123,10 @@ def finalize_options(self): setup( name="PyGreSQL", version=version, - description="Python PostgreSQL Interfaces", - long_description=__doc__.split('\n\n', 2)[1], # first passage - long_description_content_type='text/plain', - keywords="pygresql postgresql database api dbapi", - author="D'Arcy J. M. 
Cain", - author_email="darcy@PyGreSQL.org", - url="http://www.pygresql.org", - download_url="http://www.pygresql.org/download/", - project_urls={ - "Documentation": "https://pygresql.org/contents/", - "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", - "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", - "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, - platforms=["any"], - license="PostgreSQL", - py_modules=py_modules, ext_modules=[Extension( '_pg', c_sources, include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], - zip_safe=False, cmdclass=dict(build_ext=build_pg_ext), - test_suite='tests.discover', - classifiers=[ - "Development Status :: 6 - Mature", - "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", - "Operating System :: OS Independent", - "Programming Language :: C", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development :: Libraries :: Python Modules"] ) diff --git a/tests/config.py b/tests/config.py index 6e2ebd3c..acd8559a 100644 --- a/tests/config.py +++ b/tests/config.py @@ -26,9 +26,9 @@ dbport = int(dbport) try: - from .LOCAL_PyGreSQL import * # noqa: F401 + from .LOCAL_PyGreSQL import * # noqa: F403 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * # noqa: F401 + from LOCAL_PyGreSQL import * # noqa: F403 except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 32045fa4..e76e5fb9 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -9,9 +9,8 @@ __version__ = '1.15.0' -import unittest import time - +import unittest from typing import Any, Dict, Tuple @@ -462,7 +461,7 @@ def test_fetchone(self): con.close() def test_next(self): - """Extension for getting the next row""" + """Test extension for getting the next row.""" con = self._connect() try: cur = con.cursor() diff --git a/tests/test_classic.py b/tests/test_classic.py index 6319d5d5..d6763074 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,17 +1,21 @@ #!/usr/bin/python import unittest - from functools import partial -from time import sleep from threading import Thread +from time import sleep from pg import ( - DB, NotificationHandler, - Error, DatabaseError, IntegrityError, - NotSupportedError, ProgrammingError) + DB, + DatabaseError, + Error, + IntegrityError, + NotificationHandler, + NotSupportedError, + ProgrammingError, +) -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser def open_db(): @@ -28,7 +32,7 @@ class UtilityTest(unittest.TestCase): @classmethod def setUpClass(cls): - """Recreate test tables and schemas""" + """Recreate test tables and schemas.""" db = open_db() try: db.query("DROP VIEW _test_vschema") @@ -56,7 +60,7 @@ def setUpClass(cls): db.close() def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" db = open_db() db.query("TRUNCATE TABLE _test_schema") for t in ('_test1', '_test2'): @@ -64,12 +68,12 @@ def setUp(self): db.close() def 
test_invalid_name(self): - """Make sure that invalid table names are caught""" + """Make sure that invalid table names are caught.""" db = open_db() self.assertRaises(NotSupportedError, db.get_attnames, 'x.y.z') def test_schema(self): - """Does it differentiate the same table name in different schemas""" + """Check differentiation of same table name in different schemas.""" db = open_db() # see if they differentiate the table names properly self.assertEqual( diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index f7ca2a46..4436239d 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -9,18 +9,17 @@ These tests need a database to test against. """ -import unittest +import os import threading import time -import os - +import unittest from collections import namedtuple from collections.abc import Iterable from decimal import Decimal import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser windows = os.name == 'nt' @@ -284,7 +283,7 @@ def testAllQueryMembers(self): def testMethodEndcopy(self): try: self.connection.endcopy() - except IOError: + except OSError: pass def testMethodClose(self): @@ -864,7 +863,7 @@ def testGetresultUtf8(self): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') + q = q.encode() # pass the query as bytes v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) @@ -879,7 +878,7 @@ def testDictresultUtf8(self): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') + q = q.encode() v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -2013,7 +2012,7 @@ def testInserttableByteValues(self): 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "käse сыр pont-l'évêque") row_bytes = tuple( - s.encode('utf-8') if isinstance(s, str) else s + s.encode() if isinstance(s, str) else s for s in row_unicode) data = [row_bytes] * 2 self.c.inserttable('test', data) @@ -2098,7 +2097,7 @@ def testInserttableFromQuery(self): None, 'c', 'v4', None, 'text')]) def testInserttableSpecialChars(self): - class S(object): + class S: def __repr__(self): return s @@ -2187,7 +2186,7 @@ def testPutlineBytesAndUnicode(self): self.skipTest('database does not support utf8') query("copy test from stdin") try: - putline("47\tkäse\n".encode('utf8')) + putline("47\tkäse\n".encode()) putline("35\twürstel\n") finally: self.c.endcopy() @@ -2212,7 +2211,7 @@ def testGetline(self): finally: try: self.c.endcopy() - except IOError: + except OSError: pass def testGetlineBytesAndUnicode(self): @@ -2222,7 +2221,7 @@ def testGetlineBytesAndUnicode(self): query("select 'käse+würstel'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - data = [(54, 'käse'.encode('utf8')), (73, 'würstel')] + data = [(54, 'käse'.encode()), (73, 'würstel')] self.c.inserttable('test', data) query("copy test to stdout") try: @@ -2236,7 +2235,7 @@ def testGetlineBytesAndUnicode(self): finally: try: self.c.endcopy() - except IOError: + except OSError: pass def testParameterChecks(self): @@ -2715,9 +2714,9 @@ def testEscapeString(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, 'plain') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, 
"das is'' käse".encode('utf-8')) + self.assertEqual(r, "das is'' käse".encode()) r = f("that's cheesy") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheesy") @@ -2733,7 +2732,7 @@ def testEscapeBytea(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, 'plain') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") r = f("that's cheesy") diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 79c962a4..8e64949d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -9,24 +9,23 @@ These tests need a database to test against. """ -import unittest -import os -import sys import gc import json +import os +import sys import tempfile - -import pg # the module under test - +import unittest from collections import OrderedDict +from datetime import date, datetime, time, timedelta from decimal import Decimal -from datetime import date, time, datetime, timedelta from io import StringIO -from uuid import UUID -from time import strftime from operator import itemgetter +from time import strftime +from uuid import UUID + +import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser debug = False # let DB wrapper print debugging output @@ -337,7 +336,7 @@ def testMethodQueryDataError(self): def testMethodEndcopy(self): try: self.db.endcopy() - except IOError: + except OSError: pass def testMethodClose(self): @@ -507,9 +506,9 @@ def testEscapeLiteral(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, "'plain'") - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, "'that''s käse'".encode('utf-8')) + self.assertEqual(r, "'that''s käse'".encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, "'that''s käse'") @@ -526,9 +525,9 @@ def testEscapeIdentifier(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, '"plain"') - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, '"that\'s käse"'.encode('utf-8')) + self.assertEqual(r, '"that\'s käse"'.encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, '"that\'s käse"') @@ -545,9 +544,9 @@ def testEscapeString(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, "plain") - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, "that''s käse".encode('utf-8')) + self.assertEqual(r, "that''s käse".encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, "that''s käse") @@ -564,7 +563,7 @@ def testEscapeBytea(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, '\\x706c61696e') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') r = f("das is' käse") @@ -582,10 +581,10 @@ def testUnescapeBytea(self): self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode()) r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf8')) + 
self.assertEqual(r, "das is' käse".encode()) self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!') self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e') self.assertEqual(f(r'\\x746861742773206be47365'), @@ -4320,11 +4319,11 @@ def setUpClass(cls): db = DB() cls.regtypes = not db.use_regtypes() db.close() - super(TestDBClassNonStdOpts, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): - super(TestDBClassNonStdOpts, cls).tearDownClass() + super().tearDownClass() cls.reset_option('jsondecode') cls.reset_option('bool') cls.reset_option('array') diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index adddc8ce..914450f5 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -9,15 +9,13 @@ These tests do not need a database to test against. """ -import unittest - import json import re +import unittest +from datetime import timedelta import pg # the module under test -from datetime import timedelta - class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" @@ -900,10 +898,10 @@ def testUnescapeBytea(self): self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode()) r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode()) r = f(b'O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') r = f('O\\000ps\\377!') diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index bdf3a613..039ca51f 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -9,13 +9,13 @@ These tests need a database to test against. 
""" -import unittest -import tempfile import os +import tempfile +import unittest import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser windows = os.name == 'nt' @@ -151,11 +151,11 @@ def tearDown(self): if self.obj.oid: try: self.obj.close() - except (SystemError, IOError): + except (SystemError, OSError): pass try: self.obj.unlink() - except (SystemError, IOError): + except (SystemError, OSError): pass del self.obj try: @@ -270,12 +270,12 @@ def testWriteLatin1Bytes(self): def testWriteUtf8Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write('käse'.encode('utf8')) + self.obj.write('käse'.encode()) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), 'käse') + self.assertEqual(r.decode(), 'käse') def testWriteUtf8String(self): read = self.obj.read @@ -285,7 +285,7 @@ def testWriteUtf8String(self): self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), 'käse') + self.assertEqual(r.decode(), 'käse') def testSeek(self): seek = self.obj.seek diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 6062a4fa..8522fbc3 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -24,7 +24,7 @@ def __init__(self, value): self.value = value def __pg_repr__(self): - return "B'{0:b}'".format(self.value) + return f"B'{self.value:b}'" class test_PyGreSQL(dbapi20.DatabaseAPI20Test): diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d461825c..c4e8dd74 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -10,24 +10,23 @@ """ import unittest - from collections.abc import Iterable import pgdb # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class InputStream: def __init__(self, data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode() self.data = data or b'' self.sizes = [] def __str__(self): - data = self.data.decode('utf-8') + data = self.data.decode() return data def __len__(self): @@ -50,7 +49,7 @@ def __init__(self): self.sizes = [] def __str__(self): - data = self.data.decode('utf-8') + data = self.data.decode() return data def __len__(self): @@ -58,7 +57,7 @@ def __len__(self): def write(self, data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode() self.data += data self.sizes.append(len(data)) @@ -188,10 +187,10 @@ class TestCopyFrom(TestCopy): """Test the copy_from method.""" def tearDown(self): - super(TestCopyFrom, self).tearDown() + super().tearDown() self.setUp() self.truncate_table() - super(TestCopyFrom, self).tearDown() + super().tearDown() def copy_from(self, stream, **options): return self.cursor.copy_from(stream, 'copytest', **options) @@ -202,7 +201,7 @@ def data_file(self): def test_bad_params(self): call = self.cursor.copy_from - call('0\t', 'copytest'), self.cursor + call('0\t', 'copytest') call('1\t', 'copytest', format='text', sep='\t', null='', columns=['id', 'name']) self.assertRaises(TypeError, call) @@ -247,7 +246,7 @@ def test_input_bytes(self): self.copy_from(b'42\tHello, world!') self.assertEqual(self.table_data, [(42, 'Hello, world!')]) self.truncate_table() - self.copy_from(self.data_text.encode('utf-8')) + self.copy_from(self.data_text.encode()) self.check_table() def test_input_iterable(self): @@ -263,7 
+262,7 @@ def test_input_iterable_with_newlines(self): self.check_table() def test_input_iterable_bytes(self): - self.copy_from(row.encode('utf-8') + self.copy_from(row.encode() for row in self.data_text.splitlines()) self.check_table() @@ -368,7 +367,7 @@ class TestCopyTo(TestCopy): @classmethod def setUpClass(cls): - super(TestCopyTo, cls).setUpClass() + super().setUpClass() con = cls.connect() cur = con.cursor() cur.execute("set client_encoding=utf8") @@ -423,7 +422,7 @@ def test_generator_bytes(self): self.assertEqual(len(rows), 3) rows = b''.join(rows) self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode('utf-8')) + self.assertEqual(rows, self.data_text.encode()) def test_rowcount_increment(self): ret = self.copy_to() @@ -436,7 +435,7 @@ def test_decode(self): ret_decoded = ''.join(self.copy_to(decode=True)) self.assertIsInstance(ret_raw, bytes) self.assertIsInstance(ret_decoded, str) - self.assertEqual(ret_decoded, ret_raw.decode('utf-8')) + self.assertEqual(ret_decoded, ret_raw.decode()) self.check_rowcount() def test_sep(self): @@ -521,7 +520,7 @@ def test_file(self): ret = self.copy_to(stream) self.assertIs(ret, self.cursor) self.assertEqual(str(stream), self.data_text) - data = self.data_text.encode('utf-8') + data = self.data_text.encode() sizes = [len(row) + 1 for row in data.splitlines()] self.assertEqual(stream.sizes, sizes) self.check_rowcount() diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index a497914b..1a43ab7d 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -12,7 +12,7 @@ class TestClassicTutorial(unittest.TestCase): """Test the First Steps Tutorial for the classic interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("set datestyle to 'iso'") db.query("set default_with_oids=false") @@ -107,7 +107,7 @@ class TestDbApi20Tutorial(unittest.TestCase): """Test the First Steps Tutorial for the DB-API 2.0 interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" host = f"{dbhost or ''}:{dbport or -1}" con = connect(database=dbname, host=host, user=dbuser, password=dbpasswd) diff --git a/tox.ini b/tox.ini index 23fb9379..9ddc3a75 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,13 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},flake8,docs +envlist = py3{7,8,9,10,11},ruff,docs -[testenv:flake8] +[testenv:ruff] basepython = python3.11 -deps = flake8>=6,<7 +deps = ruff>=0.0.287 commands = - flake8 setup.py pg.py pgdb.py tests + ruff setup.py pg.py pgdb.py tests [testenv:docs] basepython = python3.11 @@ -16,6 +16,15 @@ deps = commands = sphinx-build -b html -nEW docs docs/_build/html +[testenv:build] +basepython = python3.11 +deps = + setuptools>=68 + wheel>=0.41 + build>=0.10 +commands = + python -m build -n -C strict -C memory-size + [testenv] passenv = PG* From 8e0859aa8a6b963a3ee7c9687f9e4123cf0bce59 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:02:12 +0200 Subject: [PATCH 138/194] Remove deprecated pgnotify function in pg module --- pg.py | 10 --------- tests/test_classic_notification.py | 34 ++---------------------------- 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/pg.py b/pg.py index 70de429e..6d4db899 100644 --- a/pg.py +++ b/pg.py @@ -1483,16 +1483,6 @@ def __call__(self): 
self.callback(None) -def pgnotify(*args, **kw): - """Create a notification handler. - - Same as NotificationHandler, under the traditional name. - """ - warnings.warn("pgnotify is deprecated, use NotificationHandler instead", - DeprecationWarning, stacklevel=2) - return NotificationHandler(*args, **kw) - - # The actual PostgreSQL database connection interface: class DB: diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index dcc06382..12d0dee8 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -11,12 +11,12 @@ import unittest import warnings -from time import sleep from threading import Thread +from time import sleep import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser debug = False # let DB wrapper print debugging output @@ -29,36 +29,6 @@ def DB(): return db -class TestPyNotifyAlias(unittest.TestCase): - """Test alternative ways of creating a NotificationHandler.""" - - def callback(self): - self.fail('Callback should not be called in this test') - - def testPgNotify(self): - db = DB() - arg_dict = {} - args = ('test_event', self.callback, arg_dict) - kwargs = dict(timeout=2, stop_event='test_stop') - with warnings.catch_warnings(record=True) as warn_msgs: - warnings.simplefilter("always") - # noinspection PyDeprecation - handler1 = pg.pgnotify(db, *args, **kwargs) - self.assertEqual(len(warn_msgs), 1) - warn_msg = warn_msgs[0] - self.assertTrue(issubclass(warn_msg.category, DeprecationWarning)) - self.assertIn('deprecated', str(warn_msg.message)) - self.assertIsInstance(handler1, pg.NotificationHandler) - handler2 = db.notification_handler(*args, **kwargs) - self.assertIsInstance(handler2, pg.NotificationHandler) - self.assertIs(handler1.db, handler2.db) - self.assertEqual(handler1.event, handler2.event) - self.assertIs(handler1.callback, handler2.callback) - self.assertIs(handler1.arg_dict, handler2.arg_dict) - self.assertEqual(handler1.timeout, handler2.timeout) - self.assertEqual(handler1.stop_event, handler2.stop_event) - - class TestSyncNotification(unittest.TestCase): """Test notification handler running in the same thread.""" From 7d055b415e25f25137e484caf8b244025eac4196 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:30:45 +0200 Subject: [PATCH 139/194] Remove deprecated pg.Query.ntuples method --- docs/contents/changelog.rst | 2 ++ docs/contents/pg/connection.rst | 8 ++++---- docs/contents/pg/query.rst | 16 ---------------- pgquery.c | 12 ------------ tests/test_classic_connection.py | 18 +----------------- tests/test_classic_notification.py | 1 - 6 files changed, 7 insertions(+), 50 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index e2b68425..67408993 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,6 +5,8 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Removed deprecated function `pg.pgnotify()`. +- Removed the deprecated method `ntuples()` of the `pg.Query` object. - Modernized code and tools for development, testing, linting and building. 
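For readers upgrading across the two removals listed above, a minimal
migration sketch (the connection parameters below are placeholders):

    import pg

    db = pg.DB(dbname='test')  # placeholder connection parameters

    # len() replaces the removed Query.ntuples() method:
    q = db.query("select generate_series(1, 3)")
    assert len(q) == 3

    # DB.notification_handler() replaces the removed pg.pgnotify();
    # it takes the same event, callback and arg_dict arguments:
    def callback(arg_dict):
        print(arg_dict)

    handler = db.notification_handler('test_event', callback, timeout=2)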
Version 5.2.5 (2023-08-28) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 1adf29d1..b175a2a0 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -114,10 +114,10 @@ result codes: if :meth:`Connection.query` returns `None`, the result-returning methods will return an empty string (`''`). It's still necessary to call a result-returning method until it returns `None`. -:meth:`Query.listfields`, :meth:`Query.fieldname`, :meth:`Query.fieldnum`, -and :meth:`Query.ntuples` only work after a call to a result-returning method -with a non-`None` return value. :meth:`Query.ntuples` returns only the number -of rows returned by the previous result-returning method. +:meth:`Query.listfields`, :meth:`Query.fieldname` and :meth:`Query.fieldnum` +only work after a call to a result-returning method with a non-``None`` return +value. Calling ``len()`` on a :class:`Query` object returns the number of rows +of the previous result-returning method. If multiple semi-colon-delimited statements are passed to :meth:`Connection.query`, only the results of the last statement are returned diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 9e2998f8..3232c115 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -400,22 +400,6 @@ negative value if it is of variable size, and a type-specific modifier value. .. versionadded:: 5.2 -ntuples -- return number of tuples in query object --------------------------------------------------- - -.. method:: Query.ntuples() - - Return number of tuples in query object - - :returns: number of tuples in :class:`Query` - :rtype: int - :raises TypeError: Too many arguments. - -This method returns the number of tuples in the query result. - -.. deprecated:: 5.1 - You can use the normal :func:`len` function instead. - memsize -- return number of bytes allocated by query result ----------------------------------------------------------- diff --git a/pgquery.c b/pgquery.c index 1196889a..194bfaa1 100644 --- a/pgquery.c +++ b/pgquery.c @@ -260,16 +260,6 @@ query_memsize(queryObject *self, PyObject *noargs) #endif /* MEMORY_SIZE */ } -/* Get number of rows. */ -static char query_ntuples__doc__[] = -"ntuples() -- return number of tuples returned by query"; - -static PyObject * -query_ntuples(queryObject *self, PyObject *noargs) -{ - return PyLong_FromLong(self->max_row); -} - /* List field names from query result. 
*/ static char query_listfields__doc__[] = "listfields() -- List field names from result"; @@ -948,8 +938,6 @@ static struct PyMethodDef query_methods[] = { METH_NOARGS, query_listfields__doc__}, {"fieldinfo", (PyCFunction) query_fieldinfo, METH_VARARGS, query_fieldinfo__doc__}, - {"ntuples", (PyCFunction) query_ntuples, - METH_NOARGS, query_ntuples__doc__}, {"memsize", (PyCFunction) query_memsize, METH_NOARGS, query_memsize__doc__}, {NULL, NULL} diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 4436239d..dc7311c4 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -268,7 +268,7 @@ def testAllQueryMembers(self): query = self.connection.query("select true where false") members = ''' dictiter dictresult fieldinfo fieldname fieldnum getresult - listfields memsize namediter namedresult ntuples + listfields memsize namediter namedresult one onedict onenamed onescalar scalariter scalarresult single singledict singlenamed singlescalar '''.split() @@ -712,22 +712,6 @@ def testFieldInfoName(self): self.assertRaises(IndexError, f, -1) self.assertRaises(IndexError, f, 4) - def testNtuples(self): # deprecated - q = "select 1 where false" - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 0) - q = ("select 1 as a, 2 as b, 3 as c, 4 as d" - " union select 5 as a, 6 as b, 7 as c, 8 as d") - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 2) - q = ("select 1 union select 2 union select 3" - " union select 4 union select 5 union select 6") - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 6) - def testLen(self): q = "select 1 where false" self.assertEqual(len(self.c.query(q)), 0) diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 12d0dee8..13a341dd 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -10,7 +10,6 @@ """ import unittest -import warnings from threading import Thread from time import sleep From 536805c83e6c5e1ff047d609e71a41eceeadaba9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:59:03 +0200 Subject: [PATCH 140/194] Make sure import statements are sorted --- pg.py | 1 - pyproject.toml | 1 + tests/test_dbapi20.py | 20 +++++++------------- tests/test_tutorial.py | 2 +- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pg.py b/pg.py index 6d4db899..f8dfb1be 100644 --- a/pg.py +++ b/pg.py @@ -150,7 +150,6 @@ 'version', '__version__'] import select -import warnings import weakref from collections import OrderedDict, namedtuple from datetime import date, datetime, time, timedelta diff --git a/pyproject.toml b/pyproject.toml index b1a184cc..dfe59c2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ line-length = 79 select = [ "E", # pycodestyle "F", # pyflakes + "I", # isort "UP", # pyupgrade "D", # pydocstyle ] diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 8522fbc3..380caf52 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -2,19 +2,13 @@ import gc import unittest - -from datetime import date, time, datetime, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from uuid import UUID as Uuid import pgdb -try: - from . import dbapi20 -except (ImportError, ValueError, SystemError): - # noinspection PyUnresolvedReferences - import dbapi20 - -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from . 
import dbapi20 +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class PgBitString: @@ -27,7 +21,7 @@ def __pg_repr__(self): return f"B'{self.value:b}'" -class test_PyGreSQL(dbapi20.DatabaseAPI20Test): +class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () @@ -38,7 +32,7 @@ class test_PyGreSQL(dbapi20.DatabaseAPI20Test): lower_func = 'lower' # For stored procedure test def setUp(self): - dbapi20.DatabaseAPI20Test.setUp(self) + super().setUp() try: con = self._connect() con.close() @@ -52,7 +46,7 @@ def setUp(self): db.query('create database ' + dbname) def tearDown(self): - dbapi20.DatabaseAPI20Test.tearDown(self) + super().tearDown() def test_version(self): v = pgdb.version @@ -542,7 +536,7 @@ def test_sqlstate(self): def test_float(self): nan, inf = float('nan'), float('inf') - from math import isnan, isinf + from math import isinf, isnan self.assertTrue(isnan(nan) and not isinf(nan)) self.assertTrue(isinf(inf) and not isnan(inf)) values = [0, 1, 0.03125, -42.53125, nan, inf, -inf, diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 1a43ab7d..3f76f39b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -5,7 +5,7 @@ from pg import DB from pgdb import connect -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class TestClassicTutorial(unittest.TestCase): From 56a034bde9f32d2d4d325eb6b139fdba0526d59f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 23:54:56 +0200 Subject: [PATCH 141/194] Use PEP8 naming conventions for test methods --- pgdb.py | 22 +- pyproject.toml | 1 + setup.py | 2 +- tests/dbapi20.py | 68 ++-- tests/test_classic_connection.py | 444 ++++++++++----------- tests/test_classic_dbwrapper.py | 604 +++++++++++++++-------------- tests/test_classic_functions.py | 118 +++--- tests/test_classic_largeobj.py | 52 +-- tests/test_classic_notification.py | 46 +-- tests/test_dbapi20.py | 12 +- 10 files changed, 687 insertions(+), 682 deletions(-) diff --git a/pgdb.py b/pgdb.py index f61522bb..5752ac4d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -148,7 +148,7 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from uuid import UUID as Uuid +from uuid import UUID as Uuid # noqa: N811 Decimal = StdDecimal @@ -729,7 +729,7 @@ def row_caster(row): return row_caster -class _quotedict(dict): +class _QuoteDict(dict): """Dictionary with auto quoting of its items. The quote attribute must be set to the desired quote function. 
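The body of the renamed class is not part of this hunk; as a standalone
sketch of the auto-quoting pattern it implements (not the library's
exact code -- repr() stands in here for the driver's quote function):

    class QuoteDict(dict):
        """Dictionary that quotes its items when they are looked up."""

        quote = staticmethod(repr)  # stand-in for the real quote function

        def __getitem__(self, key):
            # quote the stored value on every access, so that
            # "%(name)s" % params interpolates an already-quoted value
            return self.quote(super().__getitem__(key))

    params = QuoteDict(name="O'Neil")
    print("select * from t where name=%(name)s" % params)
    # prints: select * from t where name="O'Neil"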
@@ -897,7 +897,7 @@ def _quoteparams(self, string, parameters): except (TypeError, ValueError): return string # silently accept unescaped quotes if isinstance(parameters, dict): - parameters = _quotedict(parameters) + parameters = _QuoteDict(parameters) parameters.quote = self._quote else: parameters = tuple(map(self._quote, parameters)) @@ -1687,34 +1687,35 @@ def __ne__(self, other): # Mandatory type helpers defined by DB-API 2 specs: -def Date(year, month, day): +def Date(year, month, day): # noqa: N802 """Construct an object holding a date value.""" return date(year, month, day) -def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): +def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): # noqa: N802 """Construct an object holding a time value.""" return time(hour, minute, second, microsecond, tzinfo) -def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0, +def Timestamp(year, month, day, # noqa: N802 + hour=0, minute=0, second=0, microsecond=0, tzinfo=None): """Construct an object holding a time stamp value.""" return datetime(year, month, day, hour, minute, second, microsecond, tzinfo) -def DateFromTicks(ticks): +def DateFromTicks(ticks): # noqa: N802 """Construct an object holding a date value from the given ticks value.""" return Date(*localtime(ticks)[:3]) -def TimeFromTicks(ticks): +def TimeFromTicks(ticks): # noqa: N802 """Construct an object holding a time value from the given ticks value.""" return Time(*localtime(ticks)[3:6]) -def TimestampFromTicks(ticks): +def TimestampFromTicks(ticks): # noqa: N802 """Construct an object holding a time stamp from the given ticks value.""" return Timestamp(*localtime(ticks)[:6]) @@ -1725,7 +1726,8 @@ class Binary(bytes): # Additional type helpers for PyGreSQL: -def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0): +def Interval(days, # noqa: N802 + hours=0, minutes=0, seconds=0, microseconds=0): """Construct an object holding a time interval value.""" return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, microseconds=microseconds) diff --git a/pyproject.toml b/pyproject.toml index dfe59c2f..9603b825 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ select = [ "E", # pycodestyle "F", # pyflakes "I", # isort + "N", # pep8-naming "UP", # pyupgrade "D", # pydocstyle ] diff --git a/setup.py b/setup.py index 29a84bf8..09c6e2f8 100755 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ def pg_version(): extra_compile_args = ['-O2', '-funsigned-char', '-Wall', '-Wconversion'] -class build_pg_ext(build_ext): +class build_pg_ext(build_ext): # noqa: N801 """Customized build_ext command for PyGreSQL.""" description = "build the PyGreSQL C extension" diff --git a/tests/dbapi20.py b/tests/dbapi20.py index e76e5fb9..bb913475 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -55,10 +55,10 @@ class mytest(dbapi20.DatabaseAPI20Test): # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self, cursor): + def execute_ddl1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self, cursor): + def execute_ddl2(self, cursor): cursor.execute(self.ddl2) def setUp(self): @@ -134,7 +134,7 @@ def test_paramstyle(self): except AttributeError: self.fail("Driver doesn't define paramstyle") - def test_Exceptions(self): + def test_exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. 
sub = issubclass @@ -149,7 +149,7 @@ def test_Exceptions(self): self.assertTrue(sub(self.driver.ProgrammingError, self.driver.Error)) self.assertTrue(sub(self.driver.NotSupportedError, self.driver.Error)) - def test_ExceptionsAsConnectionAttributes(self): + def test_exceptions_as_connection_attributes(self): # OPTIONAL EXTENSION # Test for the optional DB API 2.0 extension, where the exceptions # are exposed as attributes on the Connection object @@ -202,7 +202,7 @@ def test_cursor_isolation(self): # the documented transaction isolation level cur1 = con.cursor() cur2 = con.cursor() - self.executeDDL1(cur1) + self.execute_ddl1(cur1) cur1.execute(f"{self.insert} into {self.table_prefix}booze" " values ('Victoria Bitter')") cur2.execute(f"select name from {self.table_prefix}booze") @@ -217,7 +217,7 @@ def test_description(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertIsNone( cur.description, 'cursor.description should be none after executing a' @@ -238,7 +238,7 @@ def test_description(self): f' Got: {cur.description[0][1]!r}') # Make sure self.description gets reset - self.executeDDL2(cur) + self.execute_ddl2(cur) self.assertIsNone( cur.description, 'cursor.description not being set to None when executing' @@ -250,7 +250,7 @@ def test_rowcount(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertIn( cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' @@ -266,7 +266,7 @@ def test_rowcount(self): cur.rowcount, (-1, 1), 'cursor.rowcount should == number of rows returned, or' ' set to -1 after executing a select statement') - self.executeDDL2(cur) + self.execute_ddl2(cur) self.assertIn( cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' @@ -303,7 +303,7 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error, self.executeDDL1, cur) + self.assertRaises(self.driver.Error, self.execute_ddl1, cur) # connection.commit should raise an Error if called after connection' # closed.' 
@@ -325,7 +325,7 @@ def test_execute(self): con.close() def _paraminsert(self, cur): - self.executeDDL2(cur) + self.execute_ddl2(cur) table_prefix = self.table_prefix insert = f"{self.insert} into {table_prefix}barflys values" cur.execute( @@ -384,7 +384,7 @@ def test_executemany(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) table_prefix = self.table_prefix insert = f'{self.insert} into {table_prefix}booze values' largs = [("Cooper's",), ("Boag's",)] @@ -428,7 +428,7 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertRaises(self.driver.Error, cur.fetchone) cur.execute(f'select name from {self.table_prefix}booze') @@ -474,7 +474,7 @@ def test_next(self): # cursor.next should raise an Error if called after # executing a query that cannot return rows - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertRaises(self.driver.Error, cur.next) # cursor.next should return None if a query retrieves no rows @@ -527,7 +527,7 @@ def test_fetchmany(self): # issuing a query self.assertRaises(self.driver.Error, cur.fetchmany, 4) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -588,7 +588,7 @@ def test_fetchmany(self): ' called after the whole result set has been fetched') self.assertIn(cur.rowcount, (-1, 6)) - self.executeDDL2(cur) + self.execute_ddl2(cur) cur.execute(f'select name from {self.table_prefix}barflys') r = cur.fetchmany() # Should get empty sequence self.assertEqual( @@ -609,7 +609,7 @@ def test_fetchall(self): # as a select) self.assertRaises(self.driver.Error, cur.fetchall) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -635,7 +635,7 @@ def test_fetchall(self): ' after the whole result set has been fetched') self.assertIn(cur.rowcount, (-1, len(self.samples))) - self.executeDDL2(cur) + self.execute_ddl2(cur) cur.execute(f'select name from {self.table_prefix}barflys') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, 0)) @@ -651,7 +651,7 @@ def test_mixedfetch(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -680,7 +680,7 @@ def test_mixedfetch(self): finally: con.close() - def help_nextset_setUp(self, cur): + def help_nextset_setup(self, cur): """Set up nextset test. Should create a procedure called deleteme that returns two result sets, @@ -696,7 +696,7 @@ def help_nextset_setUp(self, cur): # """ # cur.execute(sql) - def help_nextset_tearDown(self, cur): + def help_nextset_teardown(self, cur): """Clean up after nextset test. If cleaning up is needed after test_nextset. 
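The executemany pattern exercised by the renamed tests above, as a
minimal sketch (placeholder connection parameters; pgdb also accepts
named %(name)s placeholders in addition to positional %s):

    import pgdb

    con = pgdb.connect(database='test')  # placeholder parameters
    try:
        cur = con.cursor()
        cur.execute("create temp table booze (name varchar(20))")
        # the same statement is executed once per parameter tuple
        cur.executemany("insert into booze values (%s)",
                        [("Cooper's",), ("Boag's",)])
        con.commit()
    finally:
        con.close()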
@@ -717,7 +717,7 @@ def test_nextset(self): # self.executeDDL1(cur) # for sql in self._populate(): # cur.execute(sql) - # self.help_nextset_setUp(cur) + # self.help_nextset_setup(cur) # cur.callproc('deleteme') # number_of_rows = cur.fetchone() # self.assertEqual(number_of_rows[0], len(self.samples)) @@ -727,7 +727,7 @@ def test_nextset(self): # self.assertIsNone( # cur.nextset(), 'No more return sets, should return None') # finally: - # self.help_nextset_tearDown(cur) + # self.help_nextset_teardown(cur) # finally: # con.close() @@ -765,11 +765,11 @@ def test_setoutputsize(self): # Real test for setoutputsize is driver dependant raise NotImplementedError('Driver needed to override this test') - def test_None(self): + def test_none(self): con = self._connect() try: cur = con.cursor() - self.executeDDL2(cur) + self.execute_ddl2(cur) # inserting NULL to the second column, because some drivers might # need the first one to be primary key, which means it needs # to have a non-NULL value @@ -783,21 +783,21 @@ def test_None(self): finally: con.close() - def test_Date(self): + def test_date(self): d1 = self.driver.Date(2002, 12, 25) d2 = self.driver.DateFromTicks( time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(d1), str(d2)) - def test_Time(self): + def test_time(self): t1 = self.driver.Time(13, 45, 30) t2 = self.driver.TimeFromTicks( time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(t1), str(t2)) - def test_Timestamp(self): + def test_timestamp(self): t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) @@ -805,26 +805,26 @@ def test_Timestamp(self): # Can we assume this? 
API doesn't specify, but it seems implied self.assertEqual(str(t1), str(t2)) - def test_Binary(self): + def test_binary_string(self): self.driver.Binary(b'Something') self.driver.Binary(b'') - def test_STRING(self): + def test_string_type(self): self.assertTrue(hasattr(self.driver, 'STRING'), 'module.STRING must be defined') - def test_BINARY(self): + def test_binary_type(self): self.assertTrue(hasattr(self.driver, 'BINARY'), 'module.BINARY must be defined.') - def test_NUMBER(self): + def test_number_type(self): self.assertTrue(hasattr(self.driver, 'NUMBER'), 'module.NUMBER must be defined.') - def test_DATETIME(self): + def test_datetime_type(self): self.assertTrue(hasattr(self.driver, 'DATETIME'), 'module.DATETIME must be defined.') - def test_ROWID(self): + def test_rowid_type(self): self.assertTrue(hasattr(self.driver, 'ROWID'), 'module.ROWID must be defined.') diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index dc7311c4..ed31bed8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -48,7 +48,7 @@ def connect_nowait(): class TestCanConnect(unittest.TestCase): """Test whether a basic connection to PostgreSQL is possible.""" - def testCanConnect(self): + def test_can_connect(self): try: connection = connect() rc = connection.poll() @@ -65,7 +65,7 @@ def testCanConnect(self): except pg.Error: self.fail('Cannot close the database connection') - def testCanConnectNoWait(self): + def test_can_connect_no_wait(self): try: connection = connect_nowait() rc = connection.poll() @@ -104,21 +104,21 @@ def is_method(self, attribute): return False return callable(getattr(self.connection, attribute)) - def testClassName(self): + def test_class_name(self): self.assertEqual(self.connection.__class__.__name__, 'Connection') - def testModuleName(self): + def test_module_name(self): self.assertEqual(self.connection.__class__.__module__, 'pg') - def testStr(self): + def test_str(self): r = str(self.connection) self.assertTrue(r.startswith('= 120000: self.skipTest("database does not support tables with oids") query = self.c.query @@ -797,7 +797,7 @@ def testQueryWithOids(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testMemSize(self): + def test_mem_size(self): # noinspection PyUnresolvedReferences if pg.get_pqlib_version() < 120000: self.skipTest("pqlib does not support memsize()") @@ -823,21 +823,21 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetresulAscii(self): + def test_getresul_ascii(self): result = 'Hello, world!' q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresulAscii(self): + def test_dictresul_ascii(self): result = 'Hello, world!' q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultUtf8(self): + def test_getresult_utf8(self): result = 'Hello, wörld & мир!' q = f"select '{result}'" # pass the query as unicode @@ -853,7 +853,7 @@ def testGetresultUtf8(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultUtf8(self): + def test_dictresult_utf8(self): result = 'Hello, wörld & мир!' 
q = f"select '{result}' as greeting" try: @@ -867,7 +867,7 @@ def testDictresultUtf8(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultLatin1(self): + def test_getresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): @@ -882,7 +882,7 @@ def testGetresultLatin1(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin1(self): + def test_dictresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): @@ -897,7 +897,7 @@ def testDictresultLatin1(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultCyrillic(self): + def test_getresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): @@ -912,7 +912,7 @@ def testGetresultCyrillic(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultCyrillic(self): + def test_dictresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): @@ -927,7 +927,7 @@ def testDictresultCyrillic(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultLatin9(self): + def test_getresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): @@ -942,7 +942,7 @@ def testGetresultLatin9(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin9(self): + def test_dictresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): @@ -968,7 +968,7 @@ def setUp(self): def tearDown(self): self.c.close() - def testQueryWithNoneParam(self): + def test_query_with_none_param(self): self.assertRaises(TypeError, self.c.query, "select $1", None) self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None) self.assertEqual( @@ -978,8 +978,9 @@ def testQueryWithNoneParam(self): self.assertEqual( self.c.query("select $1::text", [[None]]).getresult(), [(None,)]) - def testQueryWithBoolParams(self, bool_enabled=None): + def test_query_with_bool_params(self, bool_enabled=None): query = self.c.query + bool_enabled_default = None if bool_enabled is not None: bool_enabled_default = pg.get_bool() pg.set_bool(bool_enabled) @@ -1003,13 +1004,12 @@ def testQueryWithBoolParams(self, bool_enabled=None): self.assertEqual(query(q, (True,)).getresult(), r_true) finally: if bool_enabled is not None: - # noinspection PyUnboundLocalVariable pg.set_bool(bool_enabled_default) - def testQueryWithBoolParamsNotDefault(self): - self.testQueryWithBoolParams(bool_enabled=not pg.get_bool()) + def test_query_with_bool_params_not_default(self): + self.test_query_with_bool_params(bool_enabled=not pg.get_bool()) - def testQueryWithIntParams(self): + def test_query_with_int_params(self): query = self.c.query self.assertEqual(query("select 1+1").getresult(), [(2,)]) self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)]) @@ -1031,7 +1031,7 @@ def testQueryWithIntParams(self): query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))).getresult(), [(15,)]) - def testQueryWithStrParams(self): + def test_query_with_str_params(self): query = self.c.query self.assertEqual( query("select $1||', world!'", ('Hello',)).getresult(), @@ -1064,7 +1064,7 @@ def testQueryWithStrParams(self): ('Hello', 'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)]) - def testQueryWithUnicodeParams(self): + def 
test_query_with_unicode_params(self): query = self.c.query try: query('set client_encoding=utf8') @@ -1076,7 +1076,7 @@ def testQueryWithUnicodeParams(self): query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult(), [('Hello, wörld!',)]) - def testQueryWithUnicodeParamsLatin1(self): + def test_query_with_unicode_params_latin1(self): query = self.c.query try: query('set client_encoding=latin1') @@ -1101,7 +1101,7 @@ def testQueryWithUnicodeParamsLatin1(self): UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', 'wörld')) - def testQueryWithUnicodeParamsCyrillic(self): + def test_query_with_unicode_params_cyrillic(self): query = self.c.query try: query('set client_encoding=iso_8859_5') @@ -1120,7 +1120,7 @@ def testQueryWithUnicodeParamsCyrillic(self): UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', 'мир!')) - def testQueryWithMixedParams(self): + def test_query_with_mixed_params(self): self.assertEqual( self.c.query( "select $1+2,$2||', world!'", (1, 'Hello')).getresult(), @@ -1131,17 +1131,17 @@ def testQueryWithMixedParams(self): (4711, None, 'Hello!')).getresult(), [(4711, None, 'Hello!')]) - def testQueryWithDuplicateParams(self): + def test_query_with_duplicate_params(self): self.assertRaises( pg.ProgrammingError, self.c.query, "select $1+$1", (1,)) self.assertRaises( pg.ProgrammingError, self.c.query, "select $1+$1", (1, 2)) - def testQueryWithZeroParams(self): + def test_query_with_zero_params(self): self.assertEqual( self.c.query("select 1+1", []).getresult(), [(2,)]) - def testQueryWithGarbage(self): + def test_query_with_garbage(self): garbage = r"'\{}+()-#[]oo324" self.assertEqual( self.c.query("select $1::text AS garbage", @@ -1159,38 +1159,38 @@ def setUp(self): def tearDown(self): self.c.close() - def testEmptyPreparedStatement(self): + def test_empty_prepared_statement(self): self.c.prepare('', '') self.assertRaises(ValueError, self.c.query_prepared, '') - def testInvalidPreparedStatement(self): + def test_invalid_prepared_statement(self): self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad') - def testDuplicatePreparedStatement(self): + def test_duplicate_prepared_statement(self): self.assertIsNone(self.c.prepare('q', 'select 1')) self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2') - def testNonExistentPreparedStatement(self): + def test_non_existent_prepared_statement(self): self.assertRaises( pg.OperationalError, self.c.query_prepared, 'does-not-exist') - def testUnnamedQueryWithoutParams(self): + def test_unnamed_query_without_params(self): self.assertIsNone(self.c.prepare('', "select 'anon'")) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) - def testNamedQueryWithoutParams(self): + def test_named_query_without_params(self): self.assertIsNone(self.c.prepare('hello', "select 'world'")) self.assertEqual( self.c.query_prepared('hello').getresult(), [('world',)]) - def testMultipleNamedQueriesWithoutParams(self): + def test_multiple_named_queries_without_params(self): self.assertIsNone(self.c.prepare('query17', "select 17")) self.assertIsNone(self.c.prepare('query42', "select 42")) self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)]) self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)]) - def testUnnamedQueryWithParams(self): + def test_unnamed_query_with_params(self): self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2")) self.assertEqual( self.c.query_prepared('', ['hello', 
'world']).getresult(), @@ -1199,7 +1199,7 @@ def testUnnamedQueryWithParams(self): self.assertEqual( self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)]) - def testMultipleNamedQueriesWithParams(self): + def test_multiple_named_queries_with_params(self): self.assertIsNone(self.c.prepare('q1', "select $1 || '!'")) self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2")) self.assertEqual( @@ -1209,21 +1209,21 @@ def testMultipleNamedQueriesWithParams(self): self.c.query_prepared('q2', ['he', 'lo']).getresult(), [('he-lo',)]) - def testDescribeNonExistentQuery(self): + def test_describe_non_existent_query(self): self.assertRaises( pg.OperationalError, self.c.describe_prepared, 'does-not-exist') - def testDescribeUnnamedQuery(self): + def test_describe_unnamed_query(self): self.c.prepare('', "select 1::int, 'a'::char") r = self.c.describe_prepared('') self.assertEqual(r.listfields(), ('int4', 'bpchar')) - def testDescribeNamedQuery(self): + def test_describe_named_query(self): self.c.prepare('myquery', "select 1 as first, 2 as second") r = self.c.describe_prepared('myquery') self.assertEqual(r.listfields(), ('first', 'second')) - def testDescribeMultipleNamedQueries(self): + def test_describe_multiple_named_queries(self): self.c.prepare('query1', "select 1::int") self.c.prepare('query2', "select 1::int, 2::int") r = self.c.describe_prepared('query1') @@ -1267,36 +1267,36 @@ def assert_proper_cast(self, value, pgtype, pytype): self.assertEqual(len(r), 1) self.assertIsInstance(r[0], pytype) - def testInt(self): + def test_int(self): self.assert_proper_cast(0, 'int', int) self.assert_proper_cast(0, 'smallint', int) self.assert_proper_cast(0, 'oid', int) self.assert_proper_cast(0, 'cid', int) self.assert_proper_cast(0, 'xid', int) - def testLong(self): + def test_long(self): self.assert_proper_cast(0, 'bigint', int) - def testFloat(self): + def test_float(self): self.assert_proper_cast(0, 'float', float) self.assert_proper_cast(0, 'real', float) self.assert_proper_cast(0, 'double precision', float) self.assert_proper_cast('infinity', 'float', float) - def testNumeric(self): + def test_numeric(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal(0), 'numeric', decimal) self.assert_proper_cast(decimal(0), 'decimal', decimal) - def testMoney(self): + def test_money(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal('0'), 'money', decimal) - def testBool(self): + def test_bool(self): bool_type = bool if pg.get_bool() else str self.assert_proper_cast('f', 'bool', bool_type) - def testDate(self): + def test_date(self): self.assert_proper_cast('1956-01-31', 'date', str) self.assert_proper_cast('10:20:30', 'interval', str) self.assert_proper_cast('08:42:15', 'time', str) @@ -1304,16 +1304,16 @@ def testDate(self): self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str) self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str) - def testText(self): + def test_text(self): self.assert_proper_cast('', 'text', str) self.assert_proper_cast('', 'char', str) self.assert_proper_cast('', 'bpchar', str) self.assert_proper_cast('', 'varchar', str) - def testBytea(self): + def test_bytea(self): self.assert_proper_cast('', 'bytea', bytes) - def testJson(self): + def test_json(self): self.assert_proper_cast('{}', 'json', dict) @@ -1326,27 +1326,27 @@ def setUp(self): def tearDown(self): self.c.close() - def testLen(self): + def test_len(self): r = self.c.query("select generate_series(3,7)") self.assertEqual(len(r), 5) - def testGetItem(self): + def 
test_get_item(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[0], (7,)) self.assertEqual(r[1], (8,)) self.assertEqual(r[2], (9,)) - def testGetItemWithNegativeIndex(self): + def test_get_item_with_negative_index(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[-1], (9,)) self.assertEqual(r[-2], (8,)) self.assertEqual(r[-3], (7,)) - def testGetItemOutOfRange(self): + def test_get_item_out_of_range(self): r = self.c.query("select generate_series(7,9)") self.assertRaises(IndexError, r.__getitem__, 3) - def testIterate(self): + def test_iterate(self): r = self.c.query("select generate_series(3,5)") self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1354,29 +1354,29 @@ def testIterate(self): # noinspection PyUnresolvedReferences self.assertIsInstance(r[1], tuple) - def testIterateTwice(self): + def test_iterate_twice(self): r = self.c.query("select generate_series(3,5)") for i in range(2): self.assertEqual(list(r), [(3,), (4,), (5,)]) - def testIterateTwoColumns(self): + def test_iterate_two_columns(self): r = self.c.query("select 1,2 union select 3,4") self.assertIsInstance(r, Iterable) self.assertEqual(list(r), [(1, 2), (3, 4)]) - def testNext(self): + def test_next(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) self.assertEqual(next(r), (9,)) self.assertRaises(StopIteration, next, r) - def testContains(self): + def test_contains(self): r = self.c.query("select generate_series(7,9)") self.assertIn((8,), r) self.assertNotIn((5,), r) - def testDictIterate(self): + def test_dict_iterate(self): r = self.c.query("select generate_series(3,5) as n").dictiter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1384,7 +1384,7 @@ def testDictIterate(self): self.assertEqual(r, [dict(n=3), dict(n=4), dict(n=5)]) self.assertIsInstance(r[1], dict) - def testDictIterateTwoColumns(self): + def test_dict_iterate_two_columns(self): r = self.c.query( "select 1 as one, 2 as two" " union select 3 as one, 4 as two").dictiter() @@ -1392,19 +1392,19 @@ def testDictIterateTwoColumns(self): r = list(r) self.assertEqual(r, [dict(one=1, two=2), dict(one=3, two=4)]) - def testDictNext(self): + def test_dict_next(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertEqual(next(r), dict(n=7)) self.assertEqual(next(r), dict(n=8)) self.assertEqual(next(r), dict(n=9)) self.assertRaises(StopIteration, next, r) - def testDictContains(self): + def test_dict_contains(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertIn(dict(n=8), r) self.assertNotIn(dict(n=5), r) - def testNamedIterate(self): + def test_named_iterate(self): r = self.c.query("select generate_series(3,5) as number").namediter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1414,7 +1414,7 @@ def testNamedIterate(self): self.assertEqual(r[1]._fields, ('number',)) self.assertEqual(r[1].number, 4) - def testNamedIterateTwoColumns(self): + def test_named_iterate_two_columns(self): r = self.c.query( "select 1 as one, 2 as two" " union select 3 as one, 4 as two").namediter() @@ -1426,7 +1426,7 @@ def testNamedIterateTwoColumns(self): self.assertEqual(r[1]._fields, ('one', 'two')) self.assertEqual(r[1].two, 4) - def testNamedNext(self): + def test_named_next(self): r = self.c.query("select generate_series(7,9) as number").namediter() self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) 
@@ -1435,12 +1435,12 @@ def testNamedNext(self): self.assertEqual(n.number, 9) self.assertRaises(StopIteration, next, r) - def testNamedContains(self): + def test_named_contains(self): r = self.c.query("select generate_series(7,9)").namediter() self.assertIn((8,), r) self.assertNotIn((5,), r) - def testScalarIterate(self): + def test_scalar_iterate(self): r = self.c.query("select generate_series(3,5)").scalariter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1448,20 +1448,20 @@ def testScalarIterate(self): self.assertEqual(r, [3, 4, 5]) self.assertIsInstance(r[1], int) - def testScalarIterateTwoColumns(self): + def test_scalar_iterate_two_columns(self): r = self.c.query("select 1, 2 union select 3, 4").scalariter() self.assertIsInstance(r, Iterable) r = list(r) self.assertEqual(r, [1, 3]) - def testScalarNext(self): + def test_scalar_next(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertEqual(next(r), 7) self.assertEqual(next(r), 8) self.assertEqual(next(r), 9) self.assertRaises(StopIteration, next, r) - def testScalarContains(self): + def test_scalar_contains(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertIn(8, r) self.assertNotIn(5, r) @@ -1476,46 +1476,46 @@ def setUp(self): def tearDown(self): self.c.close() - def testOneWithEmptyQuery(self): + def test_one_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.one()) - def testOneWithSingleRow(self): + def test_one_with_single_row(self): q = self.c.query("select 1, 2") r = q.one() self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) self.assertEqual(q.one(), None) - def testOneWithTwoRows(self): + def test_one_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.one(), (1, 2)) self.assertEqual(q.one(), (3, 4)) self.assertEqual(q.one(), None) - def testOneDictWithEmptyQuery(self): + def test_one_dict_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onedict()) - def testOneDictWithSingleRow(self): + def test_one_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onedict() self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) self.assertEqual(q.onedict(), None) - def testOneDictWithTwoRows(self): + def test_one_dict_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") self.assertEqual(q.onedict(), dict(one=1, two=2)) self.assertEqual(q.onedict(), dict(one=3, two=4)) self.assertEqual(q.onedict(), None) - def testOneNamedWithEmptyQuery(self): + def test_one_named_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onenamed()) - def testOneNamedWithSingleRow(self): + def test_one_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onenamed() self.assertEqual(r._fields, ('one', 'two')) @@ -1524,7 +1524,7 @@ def testOneNamedWithSingleRow(self): self.assertEqual(r, (1, 2)) self.assertEqual(q.onenamed(), None) - def testOneNamedWithTwoRows(self): + def test_one_named_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") r = q.onenamed() @@ -1539,24 +1539,24 @@ def testOneNamedWithTwoRows(self): self.assertEqual(r, (3, 4)) self.assertEqual(q.onenamed(), None) - def testOneScalarWithEmptyQuery(self): + def test_one_scalar_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onescalar()) - def 
testOneScalarWithSingleRow(self): + def test_one_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.onescalar() self.assertIsInstance(r, int) self.assertEqual(r, 1) self.assertEqual(q.onescalar(), None) - def testOneScalarWithTwoRows(self): + def test_one_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.onescalar(), 1) self.assertEqual(q.onescalar(), 3) self.assertEqual(q.onescalar(), None) - def testSingleWithEmptyQuery(self): + def test_single_with_empty_query(self): q = self.c.query("select 0 where false") try: q.single() @@ -1567,7 +1567,7 @@ def testSingleWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleWithSingleRow(self): + def test_single_with_single_row(self): q = self.c.query("select 1, 2") r = q.single() self.assertIsInstance(r, tuple) @@ -1576,7 +1576,7 @@ def testSingleWithSingleRow(self): self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) - def testSingleWithTwoRows(self): + def test_single_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.single() @@ -1587,7 +1587,7 @@ def testSingleWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleDictWithEmptyQuery(self): + def test_single_dict_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singledict() @@ -1598,7 +1598,7 @@ def testSingleDictWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleDictWithSingleRow(self): + def test_single_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.singledict() self.assertIsInstance(r, dict) @@ -1607,7 +1607,7 @@ def testSingleDictWithSingleRow(self): self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) - def testSingleDictWithTwoRows(self): + def test_single_dict_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singledict() @@ -1618,7 +1618,7 @@ def testSingleDictWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleNamedWithEmptyQuery(self): + def test_single_named_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlenamed() @@ -1629,7 +1629,7 @@ def testSingleNamedWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleNamedWithSingleRow(self): + def test_single_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.singlenamed() self.assertEqual(r._fields, ('one', 'two')) @@ -1642,7 +1642,7 @@ def testSingleNamedWithSingleRow(self): self.assertEqual(r.two, 2) self.assertEqual(r, (1, 2)) - def testSingleNamedWithTwoRows(self): + def test_single_named_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlenamed() @@ -1653,7 +1653,7 @@ def testSingleNamedWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleScalarWithEmptyQuery(self): + def test_single_scalar_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlescalar() @@ -1664,7 +1664,7 @@ def testSingleScalarWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleScalarWithSingleRow(self): + def 
test_single_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.singlescalar() self.assertIsInstance(r, int) @@ -1673,7 +1673,7 @@ def testSingleScalarWithSingleRow(self): self.assertIsInstance(r, int) self.assertEqual(r, 1) - def testSingleScalarWithTwoRows(self): + def test_single_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlescalar() @@ -1684,13 +1684,13 @@ def testSingleScalarWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testScalarResult(self): + def test_scalar_result(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalarresult() self.assertIsInstance(r, list) self.assertEqual(r, [1, 3]) - def testScalarIter(self): + def test_scalar_iter(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalariter() self.assertNotIsInstance(r, (list, tuple)) @@ -1809,22 +1809,22 @@ def get_back(self, encoding='utf-8'): data.append(row) return data - def testInserttable1Row(self): + def test_inserttable1_row(self): data = self.data[2:3] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttable4Rows(self): + def test_inserttable4_rows(self): data = self.data self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableFromTupleOfLists(self): + def test_inserttable_from_tuple_of_lists(self): data = tuple(list(row) for row in self.data) self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableWithDifferentRowSizes(self): + def test_inserttable_with_different_row_sizes(self): data = self.data[:-1] + [self.data[-1][:-1]] try: self.c.inserttable('test', data) @@ -1834,34 +1834,34 @@ def testInserttableWithDifferentRowSizes(self): else: self.assertFalse('expected an error') - def testInserttableFromSetofTuples(self): + def test_inserttable_from_setof_tuples(self): data = {row for row in self.data} self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictAsInterable(self): + def test_inserttable_from_dict_as_iterable(self): data = {row: None for row in self.data} self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictKeys(self): + def test_inserttable_from_dict_keys(self): data = {row: None for row in self.data} keys = data.keys() self.c.inserttable('test', keys) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictValues(self): + def test_inserttable_from_dict_values(self): data = {i: row for i, row in enumerate(self.data)} values = data.values() self.c.inserttable('test', values) self.assertEqual(self.get_back(), self.data) - def testInserttableFromGeneratorOfTuples(self): + def test_inserttable_from_generator_of_tuples(self): data = (row for row in self.data) self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromListOfSets(self): + def test_inserttable_from_list_of_sets(self): data = [set(row) for row in self.data] try: self.c.inserttable('test', data) @@ -1871,14 +1871,14 @@ def testInserttableFromListOfSets(self): else: self.assertFalse('expected an error') - def testInserttableMultipleRows(self): + def test_inserttable_multiple_rows(self): num_rows = 100 data = self.data[2:3] * num_rows self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableMultipleCalls(self): + def
test_inserttable_multiple_calls(self): num_rows = 10 data = self.data[2:3] for _i in range(num_rows): @@ -1886,23 +1886,23 @@ def testInserttableMultipleCalls(self): r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableNullValues(self): + def test_inserttable_null_values(self): data = [(None,) * 14] * 100 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableNoColumn(self): + def test_inserttable_no_column(self): data = [()] * 10 self.c.inserttable('test', data, []) self.assertEqual(self.get_back(), []) - def testInserttableOnlyOneColumn(self): + def test_inserttable_only_one_column(self): data = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 self.assertEqual(self.get_back(), data) - def testInserttableOnlyTwoColumns(self): + def test_inserttable_only_two_columns(self): data = [(bool(i % 2), i * .5) for i in range(20)] self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker @@ -1910,12 +1910,12 @@ def testInserttableOnlyTwoColumns(self): + (None,) * 6 for i in range(20)] self.assertEqual(self.get_back(), data) - def testInserttableWithDottedTableName(self): + def test_inserttable_with_dotted_table_name(self): data = self.data self.c.inserttable('public.test', data) self.assertEqual(self.get_back(), data) - def testInserttableWithInvalidTableName(self): + def test_inserttable_with_invalid_table_name(self): data = [(42,)] # check that the table name is not inserted unescaped # (this would pass otherwise since there is a column named i4) @@ -1928,7 +1928,7 @@ def testInserttableWithInvalidTableName(self): # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i4']) - def testInserttableWithInvalidDataType(self): + def test_inserttable_with_invalid_data_type(self): try: self.c.inserttable('test', 42) except TypeError as e: @@ -1936,7 +1936,7 @@ def testInserttableWithInvalidDataType(self): else: self.assertFalse('expected an error') - def testInserttableWithInvalidColumnName(self): + def test_inserttable_with_invalid_column_name(self): data = [(2, 4)] # check that the column names are not inserted unescaped # (this would pass otherwise since there are columns i2 and i4) @@ -1950,7 +1950,7 @@ def testInserttableWithInvalidColumnName(self): # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i2', 'i4']) - def testInserttableWithInvalidColumList(self): + def test_inserttable_with_invalid_column_list(self): data = self.data try: self.c.inserttable('test', data, 'invalid') @@ -1960,7 +1960,7 @@ def testInserttableWithInvalidColumList(self): else: self.assertFalse('expected an error') - def testInserttableWithHugeListOfColumnNames(self): + def test_inserttable_with_huge_list_of_column_names(self): data = self.data # try inserting data with a huge list of column names cols = ['very_long_column_name'] * 2000 @@ -1970,13 +1970,13 @@ def testInserttableWithHugeListOfColumnNames(self): cols *= 2 self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) - def testInserttableWithOutOfRangeData(self): + def test_inserttable_with_out_of_range_data(self): # try inserting data out of range for the column type # Should raise a value error because of smallint out of range self.assertRaises( ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) - def testInserttableMaxValues(self): + def test_inserttable_max_values(self):
data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, '2999-12-31', '11:59:59', 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, @@ -1984,7 +1984,7 @@ def testInserttableMaxValues(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableByteValues(self): + def test_inserttable_byte_values(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except pg.DataError: @@ -2003,7 +2003,7 @@ def testInserttableByteValues(self): data = [row_unicode] * 2 self.assertEqual(self.get_back(), data) - def testInserttableUnicodeUtf8(self): + def test_inserttable_unicode_utf8(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except pg.DataError: @@ -2018,7 +2018,7 @@ def testInserttableUnicodeUtf8(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableUnicodeLatin1(self): + def test_inserttable_unicode_latin1(self): try: self.c.query("set client_encoding=latin1") self.c.query("select '¥'") @@ -2040,7 +2040,7 @@ def testInserttableUnicodeLatin1(self): self.c.inserttable('test', data) self.assertEqual(self.get_back('latin1'), data) - def testInserttableUnicodeLatin9(self): + def test_inserttable_unicode_latin9(self): try: self.c.query("set client_encoding=latin9") self.c.query("select '€'") @@ -2057,7 +2057,7 @@ def testInserttableUnicodeLatin9(self): self.c.inserttable('test', data) self.assertEqual(self.get_back('latin9'), data) - def testInserttableNoEncoding(self): + def test_inserttable_no_encoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' @@ -2069,7 +2069,7 @@ def testInserttableNoEncoding(self): # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) - def testInserttableFromQuery(self): + def test_inserttable_from_query(self): data = self.c.query( "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," "null as dt, null as ti, null as d," @@ -2080,7 +2080,7 @@ def testInserttableFromQuery(self): (2, 4, 8, True, None, None, None, 4.5, 8.5, None, 'c', 'v4', None, 'text')]) - def testInserttableSpecialChars(self): + def test_inserttable_special_chars(self): class S: def __repr__(self): return s @@ -2093,7 +2093,7 @@ def __repr__(self): self.assertEqual( self.c.query('select t from test').getresult(), [(s,)] * 3) - def testInsertTableBigRowSize(self): + def test_insert_table_big_row_size(self): # inserting rows with a size of up to 64k bytes should work t = '*' * 50000 data = [(t,)] @@ -2105,7 +2105,7 @@ def testInsertTableBigRowSize(self): data = [(t,)] self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) - def testInsertTableSmallIntOverflow(self): + def test_insert_table_small_int_overflow(self): rest_row = self.data[2][1:] data = [(32000,) + rest_row] self.c.inserttable('test', data) @@ -2148,7 +2148,7 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - def testPutline(self): + def test_putline(self): putline = self.c.putline query = self.c.query data = list(enumerate("apple pear plum cherry banana".split())) @@ -2161,7 +2161,7 @@ def testPutline(self): r = query("select * from test").getresult() self.assertEqual(r, data) - def testPutlineBytesAndUnicode(self): + def test_putline_bytes_and_unicode(self): putline = self.c.putline query = self.c.query try: @@ -2177,7 +2177,7 @@ def testPutlineBytesAndUnicode(self): r = query("select * from 
test").getresult() self.assertEqual(r, [(47, 'käse'), (35, 'würstel')]) - def testGetline(self): + def test_getline(self): getline = self.c.getline query = self.c.query data = list(enumerate("apple banana pear plum strawberry".split())) @@ -2198,7 +2198,7 @@ def testGetline(self): except OSError: pass - def testGetlineBytesAndUnicode(self): + def test_getline_bytes_and_unicode(self): getline = self.c.getline query = self.c.query try: @@ -2222,7 +2222,7 @@ def testGetlineBytesAndUnicode(self): except OSError: pass - def testParameterChecks(self): + def test_parameter_checks(self): self.assertRaises(TypeError, self.c.putline) self.assertRaises(TypeError, self.c.getline, 'invalid') self.assertRaises(TypeError, self.c.endcopy, 'invalid') @@ -2238,7 +2238,7 @@ def tearDown(self): self.doCleanups() self.c.close() - def testGetNotify(self): + def test_get_notify(self): getnotify = self.c.getnotify query = self.c.query self.assertIsNone(getnotify()) @@ -2268,23 +2268,23 @@ def testGetNotify(self): finally: query('unlisten test_notify') - def testGetNoticeReceiver(self): + def test_get_notice_receiver(self): self.assertIsNone(self.c.get_notice_receiver()) - def testSetNoticeReceiver(self): + def test_set_notice_receiver(self): self.assertRaises(TypeError, self.c.set_notice_receiver, 42) self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid') self.assertIsNone(self.c.set_notice_receiver(lambda notice: None)) self.assertIsNone(self.c.set_notice_receiver(None)) - def testSetAndGetNoticeReceiver(self): + def test_set_and_get_notice_receiver(self): r = lambda notice: None # noqa: E731 self.assertIsNone(self.c.set_notice_receiver(r)) self.assertIs(self.c.get_notice_receiver(), r) self.assertIsNone(self.c.set_notice_receiver(None)) self.assertIsNone(self.c.get_notice_receiver()) - def testNoticeReceiver(self): + def test_notice_receiver(self): self.addCleanup(self.c.query, 'drop function bilbo_notice();') self.c.query('''create function bilbo_notice() returns void AS $$ begin @@ -2326,7 +2326,7 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetDecimalPoint(self): + def test_get_decimal_point(self): point = pg.get_decimal_point() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal_point, point) @@ -2359,7 +2359,7 @@ def testGetDecimalPoint(self): pg.set_decimal_point(point) self.assertIsNone(r) - def testSetDecimalPoint(self): + def test_set_decimal_point(self): d = pg.Decimal point = pg.get_decimal_point() self.assertRaises(TypeError, pg.set_decimal_point) @@ -2483,7 +2483,7 @@ def testSetDecimalPoint(self): pg.set_decimal_point(point) self.assertEqual(r, bad_money) - def testGetDecimal(self): + def test_get_decimal(self): decimal_class = pg.get_decimal() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal, decimal_class) @@ -2497,7 +2497,7 @@ def testGetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testSetDecimal(self): + def test_set_decimal(self): decimal_class = pg.get_decimal() # error if no parameter is passed self.assertRaises(TypeError, pg.set_decimal) @@ -2520,7 +2520,7 @@ def testSetDecimal(self): self.assertIsInstance(r, int) self.assertEqual(r, 3425) - def testGetBool(self): + def test_get_bool(self): use_bool = pg.get_bool() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bool, use_bool) @@ -2555,7 +2555,7 @@ def testGetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() # 
error if no parameter is passed self.assertRaises(TypeError, pg.set_bool) @@ -2583,7 +2583,7 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testGetByteEscaped(self): + def test_get_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped) @@ -2618,7 +2618,7 @@ def testGetByteEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if no parameter is passed self.assertRaises(TypeError, pg.set_bytea_escaped) @@ -2646,7 +2646,7 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') - def testSetRowFactorySize(self): + def test_set_row_factory_size(self): queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): @@ -2689,7 +2689,7 @@ def setUpClass(cls): db.close() cls.cls_set_up = True - def testEscapeString(self): + def test_escape_string(self): self.assertTrue(self.cls_set_up) f = pg.escape_string r = f(b'plain') @@ -2707,7 +2707,7 @@ def testEscapeString(self): r = f(r"It's bad to have a \ inside.") self.assertEqual(r, r"It''s bad to have a \\ inside.") - def testEscapeBytea(self): + def test_escape_bytea(self): self.assertTrue(self.cls_set_up) f = pg.escape_bytea r = f(b'plain') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8e64949d..1f7b3aac 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -37,7 +37,7 @@ do_not_ask_for_host_reason = 'libpq issue on Windows' -def DB(): +def DB(): # noqa: N802 """Create a DB wrapper object connecting to the test database.""" db = pg.DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) if debug: @@ -52,7 +52,7 @@ class TestAttrDict(unittest.TestCase): cls = pg.AttrDict base = OrderedDict - def testInit(self): + def test_init(self): a = self.cls() self.assertIsInstance(a, self.base) self.assertEqual(a, self.base()) @@ -65,7 +65,7 @@ def testInit(self): self.assertIsInstance(a, self.base) self.assertEqual(a, self.base(items)) - def testIter(self): + def test_iter(self): a = self.cls() self.assertEqual(list(a), []) keys = ['id', 'name', 'age'] @@ -73,7 +73,7 @@ def testIter(self): a = self.cls(items) self.assertEqual(list(a), keys) - def testKeys(self): + def test_keys(self): a = self.cls() self.assertEqual(list(a.keys()), []) keys = ['id', 'name', 'age'] @@ -81,7 +81,7 @@ def testKeys(self): a = self.cls(items) self.assertEqual(list(a.keys()), keys) - def testValues(self): + def test_values(self): a = self.cls() self.assertEqual(list(a.values()), []) items = [('id', 'int'), ('name', 'text')] @@ -89,21 +89,21 @@ def testValues(self): a = self.cls(items) self.assertEqual(list(a.values()), values) - def testItems(self): + def test_items(self): a = self.cls() self.assertEqual(list(a.items()), []) items = [('id', 'int'), ('name', 'text')] a = self.cls(items) self.assertEqual(list(a.items()), items) - def testGet(self): + def test_get(self): a = self.cls([('id', 1)]) try: self.assertEqual(a['id'], 1) except KeyError: self.fail('AttrDict should be readable') - def testSet(self): + def test_set(self): a = self.cls() try: a['id'] = 1 @@ -112,7 +112,7 @@ def testSet(self): else: self.fail('AttrDict should be read-only') - def testDel(self): + def test_del(self): a = self.cls([('id', 1)]) try: del a['id'] @@ -121,7
+121,7 @@ def testDel(self): else: self.fail('AttrDict should be read-only') - def testWriteMethods(self): + def test_write_methods(self): a = self.cls([('id', 1)]) self.assertEqual(a['id'], 1) for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': @@ -132,17 +132,17 @@ def testWriteMethods(self): class TestDBClassInit(unittest.TestCase): """Test proper handling of errors when creating DB instances.""" - def testBadParams(self): + def test_bad_params(self): self.assertRaises(TypeError, pg.DB, invalid=True) # noinspection PyUnboundLocalVariable - def testDeleteDb(self): + def test_delete_db(self): db = DB() del db.db self.assertRaises(pg.InternalError, db.close) del db - def testAsyncQueryBeforeDeletion(self): + def test_async_query_before_deletion(self): db = DB() query = db.send_query('select 1') self.assertEqual(query.getresult(), [(1,)]) @@ -151,7 +151,7 @@ def testAsyncQueryBeforeDeletion(self): del db gc.collect() - def testAsyncQueryAfterDeletion(self): + def test_async_query_after_deletion(self): db = DB() query = db.send_query('select 1') del db @@ -172,7 +172,7 @@ def tearDown(self): except pg.InternalError: pass - def testAllDBAttributes(self): + def test_all_db_attributes(self): attributes = [ 'abort', 'adapter', 'backend_pid', 'begin', @@ -210,19 +210,19 @@ def testAllDBAttributes(self): db_attributes = [a for a in self.db.__dir__() if not a.startswith('_')] self.assertEqual(attributes, db_attributes) - def testAttributeDb(self): + def test_attribute_db(self): self.assertEqual(self.db.db.db, dbname) - def testAttributeDbname(self): + def test_attribute_dbname(self): self.assertEqual(self.db.dbname, dbname) - def testAttributeError(self): + def test_attribute_error(self): error = self.db.error self.assertTrue(not error or 'krb5_' in error) self.assertEqual(self.db.error, self.db.db.error) @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) - def testAttributeHost(self): + def test_attribute_host(self): if dbhost and not dbhost.startswith('/'): host = dbhost else: @@ -231,61 +231,61 @@ def testAttributeHost(self): self.assertEqual(self.db.host, host) self.assertEqual(self.db.db.host, host) - def testAttributeOptions(self): + def test_attribute_options(self): no_options = '' options = self.db.options self.assertEqual(options, no_options) self.assertEqual(options, self.db.db.options) - def testAttributePort(self): + def test_attribute_port(self): def_port = 5432 port = self.db.port self.assertIsInstance(port, int) self.assertEqual(port, dbport or def_port) self.assertEqual(port, self.db.db.port) - def testAttributeProtocolVersion(self): + def test_attribute_protocol_version(self): protocol_version = self.db.protocol_version self.assertIsInstance(protocol_version, int) self.assertTrue(2 <= protocol_version < 4) self.assertEqual(protocol_version, self.db.db.protocol_version) - def testAttributeServerVersion(self): + def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) self.assertTrue(100000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) - def testAttributeSocket(self): + def test_attribute_socket(self): socket = self.db.socket self.assertIsInstance(socket, int) self.assertGreaterEqual(socket, 0) - def testAttributeBackendPid(self): + def test_attribute_backend_pid(self): backend_pid = self.db.backend_pid self.assertIsInstance(backend_pid, int) self.assertGreaterEqual(backend_pid, 1) - def testAttributeSslInUse(self): + def 
test_attribute_ssl_in_use(self): ssl_in_use = self.db.ssl_in_use self.assertIsInstance(ssl_in_use, bool) self.assertFalse(ssl_in_use) - def testAttributeSslAttributes(self): + def test_attribute_ssl_attributes(self): ssl_attributes = self.db.ssl_attributes self.assertIsInstance(ssl_attributes, dict) self.assertEqual(ssl_attributes, { 'cipher': None, 'compression': None, 'key_bits': None, 'library': None, 'protocol': None}) - def testAttributeStatus(self): + def test_attribute_status(self): status_ok = 1 status = self.db.status self.assertIsInstance(status, int) self.assertEqual(status, status_ok) self.assertEqual(status, self.db.db.status) - def testAttributeUser(self): + def test_attribute_user(self): no_user = 'Deprecated facility' user = self.db.user self.assertTrue(user) @@ -293,29 +293,29 @@ def testAttributeUser(self): self.assertNotEqual(user, no_user) self.assertEqual(user, self.db.db.user) - def testMethodEscapeLiteral(self): + def test_method_escape_literal(self): self.assertEqual(self.db.escape_literal(''), "''") - def testMethodEscapeIdentifier(self): + def test_method_escape_identifier(self): self.assertEqual(self.db.escape_identifier(''), '""') - def testMethodEscapeString(self): + def test_method_escape_string(self): self.assertEqual(self.db.escape_string(''), '') - def testMethodEscapeBytea(self): + def test_method_escape_bytea(self): self.assertEqual(self.db.escape_bytea('').replace( '\\x', '').replace('\\', ''), '') - def testMethodUnescapeBytea(self): + def test_method_unescape_bytea(self): self.assertEqual(self.db.unescape_bytea(''), b'') - def testMethodDecodeJson(self): + def test_method_decode_json(self): self.assertEqual(self.db.decode_json('{}'), {}) - def testMethodEncodeJson(self): + def test_method_encode_json(self): self.assertEqual(self.db.encode_json({}), '{}') - def testMethodQuery(self): + def test_method_query(self): query = self.db.query query("select 1+1") query("select 1+$1+$2", 2, 3) @@ -323,23 +323,23 @@ def testMethodQuery(self): query("select 1+$1+$2", [2, 3]) query("select 1+$1", 1) - def testMethodQueryEmpty(self): + def test_method_query_empty(self): self.assertRaises(ValueError, self.db.query, '') - def testMethodQueryDataError(self): + def test_method_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testMethodEndcopy(self): + def test_method_endcopy(self): try: self.db.endcopy() except OSError: pass - def testMethodClose(self): + def test_method_close(self): self.db.close() try: self.db.reset() @@ -354,7 +354,7 @@ def testMethodClose(self): self.assertRaises(pg.InternalError, getattr, self.db, 'error') self.assertRaises(pg.InternalError, getattr, self.db, 'absent') - def testMethodReset(self): + def test_method_reset(self): con = self.db.db self.db.reset() self.assertIs(self.db.db, con) @@ -362,7 +362,7 @@ def testMethodReset(self): self.db.close() self.assertRaises(pg.InternalError, self.db.reset) - def testMethodReopen(self): + def test_method_reopen(self): con = self.db.db self.db.reopen() self.assertIsNot(self.db.db, con) @@ -374,7 +374,7 @@ def testMethodReopen(self): self.db.query("select 1+1") self.db.close() - def testExistingConnection(self): + def test_existing_connection(self): db = pg.DB(self.db.db) self.assertIsNotNone(db.db) self.assertEqual(self.db.db, db.db) @@ -391,7 +391,7 @@ def testExistingConnection(self): db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) - def 
testExistingDbApi2Connection(self): + def test_existing_db_api2_connection(self): class DBApi2Con: @@ -461,7 +461,7 @@ def tearDown(self): self.doCleanups() self.db.close() - def createTable(self, table, definition, + def create_table(self, table, definition, temporary=True, oids=None, values=None): query = self.db.query if '"' not in table or '.' in table: @@ -491,14 +491,14 @@ def createTable(self, table, definition, q = f"insert into {table} values ({values})" query(q, params) - def testClassName(self): + def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') - def testModuleName(self): + def test_module_name(self): self.assertEqual(self.db.__module__, 'pg') self.assertEqual(self.db.__class__.__module__, 'pg') - def testEscapeLiteral(self): + def test_escape_literal(self): f = self.db.escape_literal r = f(b"plain") self.assertIsInstance(r, bytes) @@ -517,7 +517,7 @@ def testEscapeLiteral(self): self.assertEqual(f('No "quotes" must be escaped.'), "'No \"quotes\" must be escaped.'") - def testEscapeIdentifier(self): + def test_escape_identifier(self): f = self.db.escape_identifier r = f(b"plain") self.assertIsInstance(r, bytes) @@ -536,7 +536,7 @@ def testEscapeIdentifier(self): self.assertEqual(f('All "quotes" must be escaped.'), '"All ""quotes"" must be escaped."') - def testEscapeString(self): + def test_escape_string(self): f = self.db.escape_string r = f(b"plain") self.assertIsInstance(r, bytes) @@ -553,7 +553,7 @@ def testEscapeString(self): self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") - def testEscapeBytea(self): + def test_escape_bytea(self): f = self.db.escape_bytea # note that escape_byte always returns hex output since Pg 9.0, # regardless of the bytea_output setting @@ -571,7 +571,7 @@ def testEscapeBytea(self): self.assertEqual(r, '\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') - def testUnescapeBytea(self): + def test_unescape_bytea(self): f = self.db.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -591,7 +591,7 @@ def testUnescapeBytea(self): b'\\x746861742773206be47365') self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21') - def testDecodeJson(self): + def test_decode_json(self): f = self.db.decode_json self.assertIsNone(f('null')) data = { @@ -610,7 +610,7 @@ def testDecodeJson(self): self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testEncodeJson(self): + def test_encode_json(self): f = self.db.encode_json self.assertEqual(f(None), 'null') data = { @@ -623,7 +623,7 @@ def testEncodeJson(self): self.assertIsInstance(r, str) self.assertEqual(r, text) - def testGetParameter(self): + def test_get_parameter(self): f = self.db.get_parameter self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -660,14 +660,14 @@ def testGetParameter(self): self.assertIs(r, s) self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'}) - def testGetParameterServerVersion(self): + def test_get_parameter_server_version(self): r = self.db.get_parameter('server_version_num') self.assertIsInstance(r, str) s = self.db.server_version self.assertIsInstance(s, int) self.assertEqual(r, str(s)) - def testGetParameterAll(self): + def test_get_parameter_all(self): f = self.db.get_parameter r = f('all') self.assertIsInstance(r, dict) @@ -676,7 +676,7 @@ def testGetParameterAll(self): self.assertEqual(r['DateStyle'], 'ISO, YMD') self.assertEqual(r['bytea_output'], 'hex') - def testSetParameter(self): + def 
test_set_parameter(self): f = self.db.set_parameter g = self.db.get_parameter self.assertRaises(TypeError, f) @@ -720,7 +720,7 @@ def testSetParameter(self): self.assertEqual(g('standard_conforming_strings'), 'on') self.assertEqual(g('datestyle'), 'ISO, YMD') - def testResetParameter(self): + def test_reset_parameter(self): db = DB() f = db.set_parameter g = db.get_parameter @@ -761,7 +761,7 @@ def testResetParameter(self): self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testResetParameterAll(self): + def test_reset_parameter_all(self): db = DB() f = db.set_parameter self.assertRaises(ValueError, f, 'all', 0) @@ -782,7 +782,7 @@ def testResetParameterAll(self): self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testSetParameterLocal(self): + def test_set_parameter_local(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -792,7 +792,7 @@ def testSetParameterLocal(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'on') - def testSetParameterSession(self): + def test_set_parameter_session(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -802,7 +802,7 @@ def testSetParameterSession(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'off') - def testReset(self): + def test_reset(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' @@ -823,7 +823,7 @@ def testReset(self): self.assertEqual(r, default_datestyle) db.close() - def testReopen(self): + def test_reopen(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' @@ -842,21 +842,21 @@ def testReopen(self): self.assertEqual(r, default_datestyle) db.close() - def testCreateTable(self): + def test_create_table(self): table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=False, values=values) r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - def testCreateTableWithOids(self): + def test_create_table_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=True, values=values) r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) @@ -864,10 +864,10 @@ def testCreateTableWithOids(self): r = self.db.query(f'select oid from "{table}" limit 1').getresult() self.assertIsInstance(r[0][0], int) - def testQuery(self): + def test_query(self): query = self.db.query table = 'test_table' - self.createTable(table, "n integer", oids=False) + self.create_table(table, "n integer", oids=False) q = "insert into test_table values (1)" r = query(q) self.assertIsInstance(r, str) @@ -898,12 +898,12 @@ def testQuery(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testQueryWithOids(self): + def test_query_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") query = self.db.query table = 'test_table' - self.createTable(table, "n integer", oids=True) + self.create_table(table, "n integer", oids=True) q = 
"insert into test_table values (1)" r = query(q) self.assertIsInstance(r, int) @@ -932,15 +932,15 @@ def testQueryWithOids(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testMultipleQueries(self): + def test_multiple_queries(self): self.assertEqual(self.db.query( "create temporary table test_multi (n integer);" "insert into test_multi values (4711);" "select n from test_multi").getresult()[0][0], 4711) - def testQueryWithParams(self): + def test_query_with_params(self): query = self.db.query - self.createTable('test_table', 'n1 integer, n2 integer', oids=False) + self.create_table('test_table', 'n1 integer, n2 integer', oids=False) q = "insert into test_table values ($1, $2)" r = query(q, (1, 2)) self.assertEqual(r, '1') @@ -963,17 +963,17 @@ def testQueryWithParams(self): r = query(q, 4) self.assertEqual(r, '3') - def testEmptyQuery(self): + def test_empty_query(self): self.assertRaises(ValueError, self.db.query, '') - def testQueryDataError(self): + def test_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testQueryFormatted(self): + def test_query_formatted(self): f = self.db.query_formatted t = True if pg.get_bool() else 't' # test with tuple @@ -1001,7 +1001,7 @@ def testQueryFormatted(self): r = q.getresult()[0][0] self.assertEqual(r, 'alphabetagammadeltaepsilon') - def testQueryFormattedWithAny(self): + def test_query_formatted_with_any(self): f = self.db.query_formatted q = "select 2 = any(%s)" r = f(q, [[1, 3]]).getresult()[0][0] @@ -1013,7 +1013,7 @@ def testQueryFormattedWithAny(self): r = f(q, [[None]]).getresult()[0][0] self.assertIsNone(r) - def testQueryFormattedWithoutParams(self): + def test_query_formatted_without_params(self): f = self.db.query_formatted q = "select 42" r = f(q).getresult()[0][0] @@ -1025,19 +1025,19 @@ def testQueryFormattedWithoutParams(self): r = f(q, {}).getresult()[0][0] self.assertEqual(r, 42) - def testPrepare(self): + def test_prepare(self): p = self.db.prepare self.assertIsNone(p('my query', "select 'hello'")) self.assertIsNone(p('my other query', "select 'world'")) self.assertRaises( pg.ProgrammingError, p, 'my query', "select 'hello, too'") - def testPrepareUnnamed(self): + def test_prepare_unnamed(self): p = self.db.prepare self.assertIsNone(p('', "select null")) self.assertIsNone(p(None, "select null")) - def testQueryPreparedWithoutParams(self): + def test_query_prepared_without_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, 'q') p = self.db.prepare @@ -1048,7 +1048,7 @@ def testQueryPreparedWithoutParams(self): r = f('q2').getresult()[0][0] self.assertEqual(r, 42) - def testQueryPreparedWithParams(self): + def test_query_prepared_with_params(self): p = self.db.prepare p('sum', "select 1 + $1 + $2 + $3") p('cat', "select initcap($1) || ', ' || $2 || '!'") @@ -1058,7 +1058,7 @@ def testQueryPreparedWithParams(self): r = f('cat', 'hello', 'world').getresult()[0][0] self.assertEqual(r, 'Hello, world!') - def testQueryPreparedUnnamedWithOutParams(self): + def test_query_prepared_unnamed_with_out_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, None) self.assertRaises(pg.OperationalError, f, '') @@ -1076,7 +1076,7 @@ def testQueryPreparedUnnamedWithOutParams(self): r = f('').getresult()[0][0] self.assertEqual(r, 'none') - def testQueryPreparedUnnamedWithParams(self): + def test_query_prepared_unnamed_with_params(self): p = 
self.db.prepare p('', "select 1 + $1 + $2") f = self.db.query_prepared @@ -1091,13 +1091,13 @@ def testQueryPreparedUnnamedWithParams(self): r = f(None, 3, 4).getresult()[0][0] self.assertEqual(r, 9) - def testDescribePrepared(self): + def test_describe_prepared(self): self.db.prepare('count', "select 1 as first, 2 as second") f = self.db.describe_prepared r = f('count').listfields() self.assertEqual(r, ('first', 'second')) - def testDescribePreparedUnnamed(self): + def test_describe_prepared_unnamed(self): self.db.prepare('', "select null as anon") f = self.db.describe_prepared r = f().listfields() @@ -1107,7 +1107,7 @@ def testDescribePreparedUnnamed(self): r = f('').listfields() self.assertEqual(r, ('anon',)) - def testDeletePrepared(self): + def test_delete_prepared(self): f = self.db.delete_prepared f() e = pg.OperationalError @@ -1125,27 +1125,27 @@ def testDeletePrepared(self): self.assertRaises(e, f, 'q1') self.assertRaises(e, f, 'q2') - def testPkey(self): + def test_pkey(self): query = self.db.query pkey = self.db.pkey self.assertRaises(KeyError, pkey, 'test') for t in ('pkeytest', 'primary key test'): - self.createTable(f'{t}0', 'a smallint') - self.createTable(f'{t}1', 'b smallint primary key') - self.createTable(f'{t}2', 'c smallint, d smallint primary key') - self.createTable( + self.create_table(f'{t}0', 'a smallint') + self.create_table(f'{t}1', 'b smallint primary key') + self.create_table(f'{t}2', 'c smallint, d smallint primary key') + self.create_table( f'{t}3', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') - self.createTable( + self.create_table( f'{t}4', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') - self.createTable( + self.create_table( f'{t}5', 'more_than_one_letter varchar primary key') - self.createTable( + self.create_table( f'{t}6', '"with space" date primary key') - self.createTable( + self.create_table( f'{t}7', 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') @@ -1181,7 +1181,7 @@ def testPkey(self): # we get the changed primary key when the cache is flushed self.assertEqual(pkey(f'{t}1', flush=True), 'x') - def testGetDatabases(self): + def test_get_databases(self): databases = self.db.get_databases() self.assertIn('template0', databases) self.assertIn('template1', databases) @@ -1189,7 +1189,7 @@ def testGetDatabases(self): self.assertIn('postgres', databases) self.assertIn(dbname, databases) - def testGetTables(self): + def test_get_tables(self): get_tables = self.db.get_tables tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe', 'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321', @@ -1208,7 +1208,7 @@ def testGetTables(self): self.assertNotEqual(t, 'information_schema') self.assertFalse(t.startswith('pg_')) for t in tables: - self.createTable(t, 'as select 0', temporary=False) + self.create_table(t, 'as select 0', temporary=False) current_tables = get_tables() new_tables = [t for t in current_tables if t not in before_tables] expected_new_tables = ['public.' 
+ ( @@ -1218,7 +1218,7 @@ def testGetTables(self): after_tables = get_tables() self.assertEqual(after_tables, before_tables) - def testGetSystemTables(self): + def test_get_system_tables(self): get_tables = self.db.get_tables result = get_tables() self.assertNotIn('pg_catalog.pg_class', result) @@ -1230,7 +1230,7 @@ def testGetSystemTables(self): self.assertIn('pg_catalog.pg_class', result) self.assertNotIn('information_schema.tables', result) - def testGetRelations(self): + def test_get_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertIn('public.test', result) @@ -1248,7 +1248,7 @@ def testGetRelations(self): self.assertNotIn('public.test', result) self.assertNotIn('public.test_view', result) - def testGetSystemRelations(self): + def test_get_system_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertNotIn('pg_catalog.pg_class', result) @@ -1260,7 +1260,7 @@ def testGetSystemRelations(self): self.assertIn('pg_catalog.pg_class', result) self.assertIn('information_schema.tables', result) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames self.assertRaises(pg.ProgrammingError, self.db.get_attnames, 'does_not_exist') @@ -1278,7 +1278,7 @@ def testGetAttnames(self): i2='int', i4='int', i8='int', d='num', f4='float', f8='float', m='money', v4='text', c4='text', t='text')) - self.createTable('test_table', + self.create_table('test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') r = get_attnames('test_table') @@ -1292,10 +1292,10 @@ def testGetAttnames(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesWithQuotes(self): + def test_get_attnames_with_quotes(self): get_attnames = self.db.get_attnames table = 'test table for get_attnames()' - self.createTable( + self.create_table( table, '"Prime!" smallint, "much space" integer, "Questions?" 
text') r = get_attnames(table) @@ -1308,7 +1308,7 @@ def testGetAttnamesWithQuotes(self): self.assertEqual(r, { 'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'}) table = 'yet another test table for get_attnames()' - self.createTable(table, + self.create_table(table, 'a smallint, b integer, c bigint,' ' e numeric, f real, f2 double precision, m money,' ' x smallint, y smallint, z smallint,' @@ -1333,9 +1333,9 @@ def testGetAttnamesWithQuotes(self): 'u': 'text', 't': 'text', 'v': 'text', 'y': 'int', 'x': 'int', 'z': 'int'}) - def testGetAttnamesWithRegtypes(self): + def test_get_attnames_with_regtypes(self): get_attnames = self.db.get_attnames - self.createTable( + self.create_table( 'test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes @@ -1351,9 +1351,9 @@ def testGetAttnamesWithRegtypes(self): n='integer', alpha='smallint', beta='boolean', gamma='character', tau='text', v='character varying')) - def testGetAttnamesWithoutRegtypes(self): + def test_get_attnames_without_regtypes(self): get_attnames = self.db.get_attnames - self.createTable( + self.create_table( 'test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes @@ -1369,12 +1369,12 @@ def testGetAttnamesWithoutRegtypes(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesIsCached(self): + def test_get_attnames_is_cached(self): get_attnames = self.db.get_attnames int_type = 'integer' if self.regtypes else 'int' text_type = 'text' query = self.db.query - self.createTable('test_table', 'col int') + self.create_table('test_table', 'col int') r = get_attnames("test_table") self.assertIsInstance(r, dict) self.assertEqual(r, dict(col=int_type)) @@ -1395,7 +1395,7 @@ def testGetAttnamesIsCached(self): r = get_attnames("test_table", flush=True) self.assertEqual(r, dict()) - def testGetAttnamesIsOrdered(self): + def test_get_attnames_is_ordered(self): get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, OrderedDict) @@ -1414,7 +1414,7 @@ def testGetAttnamesIsOrdered(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable( + self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) @@ -1434,8 +1434,8 @@ def testGetAttnamesIsOrdered(self): else: self.skipTest('OrderedDict is not supported') - def testGetAttnamesIsAttrDict(self): - AttrDict = pg.AttrDict + def test_get_attnames_is_attr_dict(self): + AttrDict = pg.AttrDict # noqa: N806 get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) @@ -1453,7 +1453,7 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable( + self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) @@ -1470,7 +1470,7 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') - def testGetGenerated(self): + def test_get_generated(self): get_generated = self.db.get_generated server_version = self.db.server_version if server_version >= 100000: @@ -1483,7 +1483,7 @@ def testGetGenerated(self): self.assertFalse(r) if server_version >= 
100000: table = 'test_get_generated_1' - self.createTable( + self.create_table( table, 'i int generated always as identity primary key,' ' j int generated always as identity,' @@ -1494,7 +1494,7 @@ def testGetGenerated(self): self.assertEqual(r, {'i', 'j'}) if server_version >= 120000: table = 'test_get_generated_2' - self.createTable( + self.create_table( table, 'n int, m int generated always as (n + 3) stored,' ' i int generated always as identity,' @@ -1503,21 +1503,21 @@ def testGetGenerated(self): self.assertIsInstance(r, frozenset) self.assertEqual(r, {'m', 'i'}) - def testGetGeneratedIsCached(self): + def test_get_generated_is_cached(self): server_version = self.db.server_version if server_version < 100000: self.skipTest("database does not support generated columns") get_generated = self.db.get_generated query = self.db.query table = 'test_get_generated_2' - self.createTable(table, 'i int primary key') + self.create_table(table, 'i int primary key') self.assertFalse(get_generated(table)) query(f'alter table {table} alter column i' ' add generated always as identity') self.assertFalse(get_generated(table)) self.assertEqual(get_generated(table, flush=True), {'i'}) - def testHasTablePrivilege(self): + def test_has_table_privilege(self): can = self.db.has_table_privilege self.assertEqual(can('test'), True) self.assertEqual(can('test', 'select'), True) @@ -1538,13 +1538,13 @@ def testHasTablePrivilege(self): self.assertEqual(can('pg_views', 'select'), True) self.assertEqual(can('pg_views', 'delete'), False) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query table = 'get_test_table' self.assertRaises(TypeError, get) self.assertRaises(TypeError, get, table) - self.createTable(table, 'n integer, t text', + self.create_table(table, 'n integer, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) r = get(table, 2, 'n') @@ -1593,13 +1593,13 @@ def testGet(self): s.pop('n') self.assertRaises(KeyError, get, table, s) - def testGetWithOids(self): + def test_get_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") get = self.db.get query = self.db.query table = 'get_with_oid_test_table' - self.createTable(table, 'n integer, t text', oids=True, + self.create_table(table, 'n integer, t text', oids=True, values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) self.assertRaises(KeyError, get, table, {}, 'oid') @@ -1659,10 +1659,10 @@ def testGetWithOids(self): self.assertEqual(r['n'], 3) self.assertNotEqual(r[qoid], oid) - def testGetWithCompositeKey(self): + def test_get_with_composite_key(self): get = self.db.get table = 'get_test_table_1' - self.createTable( + self.create_table( table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertEqual(get(table, 2)['t'], 'b') @@ -1674,7 +1674,7 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, ('a',), ('t',))['n'], 1) self.assertEqual(get(table, ['c'], ['t'])['n'], 3) table = 'get_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -1691,10 +1691,10 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c') self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f') - def testGetWithQuotedNames(self): + def test_get_with_quoted_names(self): get = self.db.get table = 
'test table for get()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(17, 1001, 'No!')]) @@ -1704,7 +1704,7 @@ def testGetWithQuotedNames(self): self.assertEqual(r['much space'], 1001) self.assertEqual(r['Questions?'], 'No!') - def testGetFromView(self): + def test_get_from_view(self): self.db.query('delete from test where i4=14') self.db.query('insert into test (i4, v4) values(' "14, 'abc4')") @@ -1712,10 +1712,10 @@ def testGetFromView(self): self.assertIn('v4', r) self.assertEqual(r['v4'], 'abc4') - def testGetLittleBobbyTables(self): + def test_get_little_bobby_tables(self): get = self.db.get query = self.db.query - self.createTable( + self.create_table( 'test_students', 'firstname varchar primary key, nickname varchar, grade char(2)', values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'), @@ -1748,13 +1748,13 @@ def testGetLittleBobbyTables(self): self.assertEqual(len(r), 3) self.assertEqual(r[1][2], 'D-') - def testInsert(self): + def test_insert(self): insert = self.db.insert query = self.db.query bool_on = pg.get_bool() decimal = pg.get_decimal() table = 'insert_test_table' - self.createTable( + self.create_table( table, 'i2 smallint, i4 integer, i8 bigint,' ' d numeric, f4 real, f8 double precision, m money,' ' v4 varchar(4), c4 char(4), t text,' @@ -1840,12 +1840,12 @@ def testInsert(self): self.assertEqual(data, expect) query(f'truncate table "{table}"') - def testInsertWithOids(self): + def test_insert_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") insert = self.db.insert query = self.db.query - self.createTable('test_table', 'n int', oids=True) + self.create_table('test_table', 'n int', oids=True) self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1) r = insert('test_table', n=1) self.assertIsInstance(r, dict) @@ -1910,11 +1910,11 @@ def testInsertWithOids(self): r = ' '.join(str(row[0]) for row in query(q).getresult()) self.assertEqual(r, '6 7') - def testInsertWithQuotedNames(self): + def test_insert_with_quoted_names(self): insert = self.db.insert query = self.db.query table = 'test table for insert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text') r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} r = insert(table, r) @@ -1929,7 +1929,7 @@ def testInsertWithQuotedNames(self): self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - def testInsertIntoView(self): + def test_insert_into_view(self): insert = self.db.insert query = self.db.query query("truncate table test") @@ -1955,7 +1955,7 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')]) - def testInsertWithGeneratedColumns(self): + def test_insert_with_generated_columns(self): insert = self.db.insert get = self.db.get server_version = self.db.server_version @@ -1971,7 +1971,7 @@ def testInsertWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) @@ -1981,13 +1981,13 @@ def testInsertWithGeneratedColumns(self): self.assertIsInstance(r, dict) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testUpdate(self): + def test_update(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) r = self.db.get(table, 2) @@ -1998,13 +1998,13 @@ def testUpdate(self): r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithOids(self): + def test_update_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") update = self.db.update get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) s = get('test_table', 1, 'n') self.assertIsInstance(s, dict) self.assertEqual(s['n'], 1) @@ -2078,13 +2078,13 @@ def testUpdateWithOids(self): r = query(q).getresult() self.assertEqual(r, [(1, 3), (4, 7)]) - def testUpdateWithoutOid(self): + def test_update_without_oid(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', oids=False, + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) r = self.db.get(table, 2) r['t'] = 'u' @@ -2094,11 +2094,11 @@ def testUpdateWithoutOid(self): r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithCompositeKey(self): + def test_update_with_composite_key(self): update = self.db.update query = self.db.query table = 'update_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, update, table, dict(t='b')) s = dict(n=2, t='d') @@ -2121,7 +2121,7 @@ def testUpdateWithCompositeKey(self): self.assertEqual(len(r), 0) query(f'drop table "{table}"') table = 'update_test_table_2' - self.createTable(table, + self.create_table(table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -2132,11 +2132,11 @@ def 
testUpdateWithCompositeKey(self): r = [r[0] for r in query(q).getresult()] self.assertEqual(r, ['c', 'x']) - def testUpdateWithQuotedNames(self): + def test_update_with_quoted_names(self): update = self.db.update query = self.db.query table = 'test table for update()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(13, 3003, 'Why!')]) r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} @@ -2152,7 +2152,7 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - def testUpdateWithGeneratedColumns(self): + def test_update_with_generated_columns(self): update = self.db.update get = self.db.get query = self.db.query @@ -2169,7 +2169,7 @@ def testUpdateWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = query(f'insert into {table} (i, d) values ({i}, {d})') @@ -2184,13 +2184,13 @@ def testUpdateWithGeneratedColumns(self): j += 1 self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testUpsert(self): + def test_upsert(self): upsert = self.db.upsert query = self.db.query self.assertRaises(pg.ProgrammingError, upsert, 'test', i2=2, i4=4, i8=8) table = 'upsert_test_table' - self.createTable(table, 'n integer primary key, t text') + self.create_table(table, 'n integer primary key, t text') s = dict(n=1, t='x') r = upsert(table, s) self.assertIs(r, s) @@ -2257,13 +2257,13 @@ def testUpsert(self): r = query(q).getresult() self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) - def testUpsertWithOids(self): + def test_upsert_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") upsert = self.db.upsert get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) self.assertRaises(pg.ProgrammingError, upsert, 'test_table', dict(n=2)) r = get('test_table', 1, 'n') @@ -2338,11 +2338,11 @@ def testUpsertWithOids(self): q = query("select n, m from test_table order by n limit 3") self.assertEqual(q.getresult(), [(1, 5), (2, 10)]) - def testUpsertWithCompositeKey(self): + def test_upsert_with_composite_key(self): upsert = self.db.upsert query = self.db.query table = 'upsert_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)') s = dict(n=1, m=2, t='x') r = upsert(table, s) @@ -2400,11 +2400,11 @@ def testUpsertWithCompositeKey(self): r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')]) - def testUpsertWithQuotedNames(self): + def test_upsert_with_quoted_names(self): upsert = self.db.upsert query = self.db.query table = 'test table for upsert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text') s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} r = upsert(table, s) @@ -2424,7 +2424,7 @@ def testUpsertWithQuotedNames(self): r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'No.')]) - def testUpsertWithGeneratedColumns(self): + def test_upsert_with_generated_columns(self): upsert = self.db.upsert get = self.db.get server_version = self.db.server_version @@ -2440,7 +2440,7 @@ def testUpsertWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) @@ -2455,7 +2455,7 @@ def testUpsertWithGeneratedColumns(self): r = get(table, d) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testClear(self): + def test_clear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' r = clear('test') @@ -2463,7 +2463,7 @@ def testClear(self): i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) table = 'clear_test_table' - self.createTable( + self.create_table( table, 'n integer, f float, b boolean, d date, t text') r = clear(table) result = dict(n=0, f=0, b=f, d='', t='') @@ -2476,10 +2476,10 @@ def testClear(self): result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=1) self.assertEqual(r, result) - def testClearWithQuotedNames(self): + def test_clear_with_quoted_names(self): clear = self.db.clear table = 'test table for clear()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') r = clear(table) @@ -2488,13 +2488,13 @@ def testClearWithQuotedNames(self): self.assertEqual(r['much space'], 0) self.assertEqual(r['Questions?'], '') - def testDelete(self): + def test_delete(self): delete = self.db.delete query = self.db.query self.assertRaises(pg.ProgrammingError, delete, 'test', dict(i2=2, i4=4, i8=8)) table = 'delete_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) r = self.db.get(table, 1) @@ -2521,13 +2521,13 @@ def testDelete(self): s = delete(table, r) self.assertEqual(s, 0) - def testDeleteWithOids(self): + def test_delete_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") delete = self.db.delete get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=range(1, 7)) + self.create_table('test_table', 'n int', oids=True, values=range(1, 7)) r = dict(n=3) self.assertRaises(pg.ProgrammingError, delete, 'test_table', r) s = get('test_table', 1, 'n') @@ -2617,10 +2617,10 @@ def testDeleteWithOids(self): self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (None, 0)) - def testDeleteWithCompositeKey(self): + def test_delete_with_composite_key(self): query = self.db.query table = 'delete_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) @@ -2630,7 +2630,7 @@ def testDeleteWithCompositeKey(self): r = query(f'select t from "{table}" where n=3').getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' - 
self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -2648,11 +2648,11 @@ def testDeleteWithCompositeKey(self): f' order by m').getresult()] self.assertEqual(r, ['f']) - def testDeleteWithQuotedNames(self): + def test_delete_with_quoted_names(self): delete = self.db.delete query = self.db.query table = 'test table for delete()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) @@ -2667,12 +2667,12 @@ def testDeleteWithQuotedNames(self): r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 0) - def testDeleteReferenced(self): + def test_delete_referenced(self): delete = self.db.delete query = self.db.query - self.createTable( + self.create_table( 'test_parent', 'n smallint primary key', values=range(3)) - self.createTable( + self.create_table( 'test_child', 'n smallint primary key references test_parent', values=range(3)) q = ("select (select count(*) from test_parent)," @@ -2705,9 +2705,10 @@ def testDeleteReferenced(self): q = "select n from test_parent natural join test_child limit 2" self.assertEqual(query(q).getresult(), [(1,)]) - def testTempCrud(self): + def test_temp_crud(self): table = 'test_temp_table' - self.createTable(table, "n int primary key, t varchar", temporary=True) + self.create_table(table, "n int primary key, t varchar", + temporary=True) self.db.insert(table, dict(n=1, t='one')) self.db.insert(table, dict(n=2, t='too')) self.db.insert(table, dict(n=3, t='three')) @@ -2720,14 +2721,14 @@ def testTempCrud(self): r = self.db.query(f'select n, t from {table} order by 1').getresult() self.assertEqual(r, [(1, 'one'), (3, 'three')]) - def testTruncate(self): + def test_truncate(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, None) self.assertRaises(TypeError, truncate, 42) self.assertRaises(TypeError, truncate, dict(test_table=None)) query = self.db.query - self.createTable('test_table', 'n smallint', - temporary=False, values=[1] * 3) + self.create_table('test_table', 'n smallint', + temporary=False, values=[1] * 3) q = "select count(*) from test_table" r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2741,7 +2742,7 @@ def testTruncate(self): truncate('public.test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - self.createTable('test_table_2', 'n smallint', temporary=True) + self.create_table('test_table_2', 'n smallint', temporary=True) for t in (list, tuple, set): for i in range(3): query("insert into test_table values (1)") @@ -2754,11 +2755,11 @@ def testTruncate(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateRestart(self): + def test_truncate_restart(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', restart='invalid') query = self.db.query - self.createTable('test_table', 'n serial, t text') + self.create_table('test_table', 'n serial, t text') for n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" @@ -2779,13 +2780,13 @@ def testTruncateRestart(self): r = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) - def testTruncateCascade(self): + def test_truncate_cascade(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid') query = self.db.query - 
self.createTable('test_parent', 'n smallint primary key', + self.create_table('test_parent', 'n smallint primary key', values=range(3)) - self.createTable('test_child', + self.create_table('test_child', 'n smallint primary key references test_parent (n)', values=range(3)) q = ("select (select count(*) from test_parent)," @@ -2817,12 +2818,12 @@ def testTruncateCascade(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateOnly(self): + def test_truncate_only(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', only='invalid') query = self.db.query - self.createTable('test_parent', 'n smallint') - self.createTable('test_child', 'm smallint) inherits (test_parent') + self.create_table('test_parent', 'n smallint') + self.create_table('test_child', 'm smallint) inherits (test_parent') for n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") @@ -2854,8 +2855,9 @@ def testTruncateOnly(self): self.assertEqual(r, (0, 0)) self.assertRaises(ValueError, truncate, 'test_parent*', only=True) truncate('test_parent*', only=False) - self.createTable('test_parent_2', 'n smallint') - self.createTable('test_child_2', 'm smallint) inherits (test_parent_2') + self.create_table('test_parent_2', 'n smallint') + self.create_table('test_child_2', + 'm smallint) inherits (test_parent_2') for t in '', '_2': for n in range(3): query(f"insert into test_parent{t} (n) values (1)") @@ -2877,11 +2879,11 @@ def testTruncateOnly(self): ['test_parent*', 'test_child'], only=[True, False]) truncate(['test_parent*', 'test_child'], only=[False, True]) - def testTruncateQuoted(self): + def test_truncate_quoted(self): truncate = self.db.truncate query = self.db.query table = "test table for truncate()" - self.createTable(table, 'n smallint', temporary=False, values=[1] * 3) + self.create_table(table, 'n smallint', temporary=False, values=[1] * 3) q = f'select count(*) from "{table}"' r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2897,7 +2899,7 @@ def testTruncateQuoted(self): self.assertEqual(r, 0) # noinspection PyUnresolvedReferences - def testGetAsList(self): + def test_get_as_list(self): get_as_list = self.db.get_as_list self.assertRaises(TypeError, get_as_list) self.assertRaises(TypeError, get_as_list, None) @@ -2908,7 +2910,7 @@ def testGetAsList(self): named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')] - self.createTable( + self.create_table( table, 'id smallint primary key, name varchar', values=names) r = get_as_list(table) self.assertIsInstance(r, list) @@ -3010,7 +3012,7 @@ def testGetAsList(self): self.assertEqual(t, ('bart',)) # noinspection PyUnresolvedReferences - def testGetAsDict(self): + def test_get_as_dict(self): get_as_dict = self.db.get_as_dict self.assertRaises(TypeError, get_as_dict) self.assertRaises(TypeError, get_as_dict, None) @@ -3023,7 +3025,7 @@ def testGetAsDict(self): named = hasattr(r, 'colname') colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'), (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')] - self.createTable( + self.create_table( table, 'id smallint primary key, rgb char(7), name varchar', values=colors) # keyname must be string, list or tuple @@ -3178,9 +3180,9 @@ def testGetAsDict(self): r = get_as_dict(table, keyname='id') self.assertEqual(r, expected) - def testTransaction(self): + def test_transaction(self): query = self.db.query - self.createTable('test_table', 'n integer', 
temporary=False) + self.create_table('test_table', 'n integer', temporary=False) self.db.begin() query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3217,14 +3219,14 @@ def testTransaction(self): query, "insert into test_table values (0)") self.db.abort() - def testTransactionAliases(self): + def test_transaction_aliases(self): self.assertEqual(self.db.begin, self.db.start) self.assertEqual(self.db.commit, self.db.end) self.assertEqual(self.db.rollback, self.db.abort) - def testContextManager(self): + def test_context_manager(self): query = self.db.query - self.createTable('test_table', 'n integer check(n>0)') + self.create_table('test_table', 'n integer check(n>0)') with self.db: query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3249,9 +3251,9 @@ def testContextManager(self): "select * from test_table order by 1").getresult()] self.assertEqual(r, [1, 2, 5, 7]) - def testBytea(self): + def test_bytea(self): query = self.db.query - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = self.db.escape_bytea(s) query('insert into bytea_test values(3, $1)', (r,)) @@ -3267,10 +3269,10 @@ def testBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testInsertUpdateGetBytea(self): + def test_insert_update_get_bytea(self): query = self.db.query unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') # insert null value r = self.db.insert('bytea_test', n=0, data=None) self.assertIsInstance(r, dict) @@ -3341,8 +3343,8 @@ def testInsertUpdateGetBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testUpsertBytea(self): - self.createTable('bytea_test', 'n smallint primary key, data bytea') + def test_upsert_bytea(self): + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = dict(n=7, data=s) r = self.db.upsert('bytea_test', r) @@ -3363,8 +3365,8 @@ def testUpsertBytea(self): self.assertIn('data', r) self.assertIsNone(r['data']) - def testInsertGetJson(self): - self.createTable('json_test', 'n smallint primary key, data json') + def test_insert_get_json(self): + self.create_table('json_test', 'n smallint primary key, data json') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('json_test', n=0, data=None) @@ -3427,8 +3429,8 @@ def testInsertGetJson(self): self.assertIsInstance(r[0][0], str if jsondecode is None else dict) self.assertEqual(r[0][0], r[1][0]) - def testInsertGetJsonb(self): - self.createTable('jsonb_test', + def test_insert_get_jsonb(self): + self.create_table('jsonb_test', 'n smallint primary key, data jsonb') jsondecode = pg.get_jsondecode() # insert null value @@ -3485,9 +3487,9 @@ def testInsertGetJsonb(self): self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testArray(self): + def test_array(self): returns_arrays = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'id smallint, i2 smallint[], i4 integer[], i8 bigint[],' ' d numeric[], f4 real[], f8 double precision[], m money[],' @@ -3545,10 +3547,10 @@ def testArray(self): else: self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}') - def testArrayLiteral(self): + def 
test_array_literal(self): insert = self.db.insert returns_arrays = pg.get_array() - self.createTable('arraytest', 'i int[], t text[]') + self.create_table('arraytest', 'i int[], t text[]') r = dict(i=[1, 2, 3], t=['a', 'b', 'c']) insert('arraytest', r) if returns_arrays: @@ -3565,8 +3567,8 @@ def testArrayLiteral(self): else: self.assertEqual(r['i'], '{1,2,3}') self.assertEqual(r['t'], '{a,b,c}') - L = pg.Literal - r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']")) + Lit = pg.Literal # noqa: N806 + r = dict(i=Lit("ARRAY[1, 2, 3]"), t=Lit("ARRAY['a', 'b', 'c']")) self.db.insert('arraytest', r) if returns_arrays: self.assertEqual(r['i'], [1, 2, 3]) @@ -3577,9 +3579,9 @@ def testArrayLiteral(self): r = dict(i="1, 2, 3", t="'a', 'b', 'c'") self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r) - def testArrayOfIds(self): + def test_array_of_ids(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'i serial primary key, c cid[], o oid[], x xid[]') r = self.db.get_attnames('arraytest') if self.regtypes: @@ -3601,9 +3603,9 @@ def testArrayOfIds(self): else: self.assertEqual(r['o'], '{21,22,23}') - def testArrayOfText(self): + def test_array_of_text(self): array_on = pg.get_array() - self.createTable('arraytest', 'id serial primary key, data text[]') + self.create_table('arraytest', 'id serial primary key, data text[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"', @@ -3625,10 +3627,10 @@ def testArrayOfText(self): self.assertIsNone(r['data'][2]) # noinspection PyUnresolvedReferences - def testArrayOfBytea(self): + def test_array_of_bytea(self): array_on = pg.get_array() bytea_escaped = pg.get_bytea_escaped() - self.createTable('arraytest', 'id serial primary key, data bytea[]') + self.create_table('arraytest', 'id serial primary key, data bytea[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'bytea[]') data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"', @@ -3654,8 +3656,8 @@ def testArrayOfBytea(self): else: self.assertNotEqual(r['data'], data) - def testArrayOfJson(self): - self.createTable('arraytest', 'id serial primary key, data json[]') + def test_array_of_json(self): + self.create_table('arraytest', 'id serial primary key, data json[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3696,8 +3698,8 @@ def testArrayOfJson(self): else: self.assertEqual(r, '{NULL,NULL}') - def testArrayOfJsonb(self): - self.createTable('arraytest', 'id serial primary key, data jsonb[]') + def test_array_of_jsonb(self): + self.create_table('arraytest', 'id serial primary key, data jsonb[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3739,9 +3741,9 @@ def testArrayOfJsonb(self): self.assertEqual(r, '{NULL,NULL}') # noinspection PyUnresolvedReferences - def testDeepArray(self): + def test_deep_array(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'id serial primary key, data text[][][]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') @@ -3760,13 +3762,13 @@ def testDeepArray(self): self.assertTrue(r['data'].startswith('{{{"Hello,')) # noinspection PyUnresolvedReferences - def testInsertUpdateGetRecord(self): + def 
test_insert_update_get_record(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint, married bool,' ' weight real, salary money)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', + self.create_table('test_person', 'id serial primary key, person test_person_type', oids=False, temporary=False) attnames = self.db.get_attnames('test_person') @@ -3859,12 +3861,12 @@ def testInsertUpdateGetRecord(self): self.assertIsNone(r['person']) # noinspection PyUnresolvedReferences - def testRecordInsertBytea(self): + def test_record_insert_bytea(self): query = self.db.query query('create type test_person_type as' ' (name text, picture bytea)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3879,11 +3881,11 @@ def testRecordInsertBytea(self): self.assertEqual(p.picture, person[1]) self.assertIsInstance(p.picture, bytes) - def testRecordInsertJson(self): + def test_record_insert_json(self): query = self.db.query query('create type test_person_type as (name text, data json)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3902,12 +3904,12 @@ def testRecordInsertJson(self): self.assertIsInstance(p.data, dict) # noinspection PyUnresolvedReferences - def testRecordLiteral(self): + def test_record_literal(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] if self.regtypes: @@ -3929,7 +3931,7 @@ def testRecordLiteral(self): self.assertEqual(p.age, 61) self.assertIsInstance(p.age, int) - def testDate(self): + def test_date(self): query = self.db.query for datestyle in ( 'ISO', 'Postgres, MDY', 'Postgres, DMY', @@ -3953,7 +3955,7 @@ def testDate(self): self.assertEqual(r[0], date.max) self.assertEqual(r[1], date.min) - def testTime(self): + def test_time(self): query = self.db.query d = time(15, 9, 26) q = "select $1::time" @@ -3966,7 +3968,7 @@ def testTime(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimetz(self): + def test_timetz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): @@ -3984,7 +3986,7 @@ def testTimetz(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimestamp(self): + def test_timestamp(self): query = self.db.query for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): @@ -4018,7 +4020,7 @@ def testTimestamp(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def testTimestamptz(self): + def test_timestamptz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): @@ -4057,7 +4059,7 @@ def testTimestamptz(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def 
testInterval(self): + def test_interval(self): query = self.db.query for intervalstyle in ( 'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'): @@ -4077,7 +4079,7 @@ def testInterval(self): self.assertIsInstance(r, timedelta) self.assertEqual(r, d) - def testDateAndTimeArrays(self): + def test_date_and_time_arrays(self): dt = (date(2016, 3, 14), time(15, 9, 26)) q = "select ARRAY[$1::date], ARRAY[$2::time]" r = self.db.query(q, dt).getresult()[0] @@ -4086,7 +4088,7 @@ def testDateAndTimeArrays(self): self.assertIsInstance(r[1], list) self.assertEqual(r[1][0], dt[1]) - def testHstore(self): + def test_hstore(self): try: self.db.query("select 'k=>v'::hstore") except pg.DatabaseError: @@ -4103,14 +4105,14 @@ def testHstore(self): self.assertIsInstance(r, dict) self.assertEqual(r, d) - def testUuid(self): + def test_uuid(self): d = UUID('{12345678-1234-5678-1234-567812345678}') q = 'select $1::uuid' r = self.db.query(q, (d,)).getresult()[0][0] self.assertIsInstance(r, UUID) self.assertEqual(r, d) - def testDbTypesInfo(self): + def test_db_types_info(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('numeric', dbtypes) @@ -4158,7 +4160,7 @@ def testDbTypesInfo(self): self.assertEqual(typlen.category, 'N') # numeric # noinspection PyUnresolvedReferences - def testDbTypesTypecast(self): + def test_db_types_typecast(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('int4', dbtypes) @@ -4185,7 +4187,7 @@ def testDbTypesTypecast(self): dbtypes.reset_typecast('circle') self.assertIsNone(dbtypes.get_typecast('circle')) - def testGetSetTypeCast(self): + def test_get_set_type_cast(self): get_typecast = pg.get_typecast set_typecast = pg.set_typecast dbtypes = self.db.dbtypes @@ -4209,7 +4211,7 @@ def testGetSetTypeCast(self): set_typecast('circle', cast_circle) self.assertIs(get_typecast('circle'), cast_circle) - def testNotificationHandler(self): + def test_notification_handler(self): # the notification handler itself is tested separately f = self.db.notification_handler callback = lambda arg_dict: None # noqa: E731 @@ -4286,11 +4288,11 @@ def testNotificationHandler(self): self.db.reopen() self.assertIsNone(handler.db) - def testInserttableFromQuery(self): + def test_inserttable_from_query(self): # use inserttable() to copy from one table to another query = self.db.query - self.createTable('test_table_from', 'n integer, t timestamp') - self.createTable('test_table_to', 'n integer, t timestamp') + self.create_table('test_table_from', 'n integer, t timestamp') + self.create_table('test_table_to', 'n integer, t timestamp') for i in range(1, 4): query("insert into test_table_from values ($1, now())", i) n = self.db.inserttable( @@ -4355,7 +4357,7 @@ def tearDown(self): except pg.InternalError: pass - def testGuessSimpleType(self): + def test_guess_simple_type(self): f = self.adapter.guess_simple_type self.assertEqual(f(pg.Bytea(b'test')), 'bytea') self.assertEqual(f('string'), 'text') @@ -4376,7 +4378,7 @@ def testGuessSimpleType(self): self.assertEqual(list(r.attnames.values()), [ 'text', 'bool', 'int', 'float', 'int[]', 'bool[]']) - def testAdaptQueryTypedList(self): + def test_adapt_query_typed_list(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) self.assertRaises( @@ -4416,7 +4418,7 @@ def testAdaptQueryTypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def 
testAdaptQueryTypedListWithTypesAsString(self): + def test_adapt_query_typed_list_with_types_as_string(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), 'int2') self.assertRaises( @@ -4427,7 +4429,7 @@ def testAdaptQueryTypedListWithTypesAsString(self): self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) - def testAdaptQueryTypedListWithTypesAsClasses(self): + def test_adapt_query_typed_list_with_types_as_classes(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), (int,)) self.assertRaises( @@ -4438,7 +4440,7 @@ def testAdaptQueryTypedListWithTypesAsClasses(self): self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) - def testAdaptQueryTypedListWithJson(self): + def test_adapt_query_typed_list_with_json(self): format_query = self.adapter.format_query value = {'test': [1, "it's fine", 3]} sql, params = format_query("select %s", (value,), 'json') @@ -4453,7 +4455,7 @@ def testAdaptQueryTypedListWithJson(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) - def testAdaptQueryTypedWithHstore(self): + def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query value = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') @@ -4468,7 +4470,7 @@ def testAdaptQueryTypedWithHstore(self): self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) - def testAdaptQueryTypedWithUuid(self): + def test_adapt_query_typed_with_uuid(self): format_query = self.adapter.format_query value = '12345678-1234-5678-1234-567812345678' sql, params = format_query("select %s", (value,), 'uuid') @@ -4483,7 +4485,7 @@ def testAdaptQueryTypedWithUuid(self): self.assertEqual(sql, "select $1") self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) - def testAdaptQueryTypedDict(self): + def test_adapt_query_typed_dict(self): format_query = self.adapter.format_query self.assertRaises( TypeError, format_query, @@ -4527,7 +4529,7 @@ def testAdaptQueryTypedDict(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryUntypedList(self): + def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query values = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values) @@ -4552,21 +4554,21 @@ def testAdaptQueryUntypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryUntypedListWithJson(self): + def test_adapt_query_untyped_list_with_json(self): format_query = self.adapter.format_query value = pg.Json({'test': [1, "it's fine", 3]}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) - def testAdaptQueryUntypedWithHstore(self): + def test_adapt_query_untyped_with_hstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) - def testAdaptQueryUntypedDict(self): + def test_adapt_query_untyped_dict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( @@ -4593,7 +4595,7 @@ 
def testAdaptQueryUntypedDict(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryInlineList(self): + def test_adapt_query_inline_list(self): format_query = self.adapter.format_query values = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values, inline=True) @@ -4621,7 +4623,7 @@ def testAdaptQueryInlineList(self): sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) - def testAdaptQueryInlineListWithJson(self): + def test_adapt_query_inline_list_with_json(self): format_query = self.adapter.format_query value = pg.Json({'test': [1, "it's fine", 3]}) sql, params = format_query("select %s", (value,), inline=True) @@ -4629,7 +4631,7 @@ def testAdaptQueryInlineListWithJson(self): sql, "select '{\"test\": [1, \"it''s fine\", 3]}'::json") self.assertEqual(params, []) - def testAdaptQueryInlineListWithHstore(self): + def test_adapt_query_inline_list_with_hstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), inline=True) @@ -4637,7 +4639,7 @@ def testAdaptQueryInlineListWithHstore(self): sql, "select 'one=>\"it''s fine\",two=>2'::hstore") self.assertEqual(params, []) - def testAdaptQueryInlineDict(self): + def test_adapt_query_inline_dict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( @@ -4668,7 +4670,7 @@ def testAdaptQueryInlineDict(self): sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) - def testAdaptQueryWithPgRepr(self): + def test_adapt_query_with_pg_repr(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s', object(), inline=True) @@ -4739,7 +4741,7 @@ def tearDown(self): self.doCleanups() self.db.close() - def testGetTables(self): + def test_get_tables(self): tables = self.db.get_tables() for num_schema in range(5): if num_schema: @@ -4750,7 +4752,7 @@ def testGetTables(self): schema + ".t" + str(num_schema)): self.assertIn(t, tables) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames query = self.db.query result = {'d': 'int', 'n': 'int'} @@ -4774,10 +4776,10 @@ def testGetAttnames(self): r = get_attnames("t3m") self.assertEqual(r, result_m) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query - PrgError = pg.ProgrammingError + PrgError = pg.ProgrammingError # noqa: N806 self.assertEqual(get("t", 1, 'n')['d'], 0) self.assertEqual(get("t0", 1, 'n')['d'], 0) self.assertEqual(get("public.t", 1, 'n')['d'], 0) @@ -4798,7 +4800,7 @@ def testGet(self): self.assertEqual(get("t", 1, 'n')['d'], 1) self.assertEqual(get("s4.t4", 1, 'n')['d'], 4) - def testMunging(self): + def test_munging(self): get = self.db.get query = self.db.query r = get("t", 1, 'n') @@ -4819,7 +4821,7 @@ def testMunging(self): else: self.assertNotIn('oid(t)', r) - def testQueryInformationSchema(self): + def test_query_information_schema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array @@ -4853,30 +4855,30 @@ def send_queries(self): self.db.query("select 1") self.db.query("select 2") - def testDebugDefault(self): + def test_debug_default(self): if debug: self.assertEqual(self.db.debug, debug) else: self.assertIsNone(self.db.debug) - def testDebugIsFalse(self): + def test_debug_is_false(self): 
self.db.debug = False self.send_queries() self.assertEqual(self.get_output(), "") - def testDebugIsTrue(self): + def test_debug_is_true(self): self.db.debug = True self.send_queries() self.assertEqual(self.get_output(), "select 1\nselect 2\n") - def testDebugIsString(self): + def test_debug_is_string(self): self.db.debug = "Test with string: %s." self.send_queries() self.assertEqual( self.get_output(), "Test with string: select 1.\nTest with string: select 2.\n") - def testDebugIsFileLike(self): + def test_debug_is_file_like(self): with tempfile.TemporaryFile('w+') as debug_file: self.db.debug = debug_file self.send_queries() @@ -4885,7 +4887,7 @@ def testDebugIsFileLike(self): self.assertEqual(output, "select 1\nselect 2\n") self.assertEqual(self.get_output(), "") - def testDebugIsCallable(self): + def test_debug_is_callable(self): output = [] self.db.debug = output.append self.db.query("select 1") @@ -4893,7 +4895,7 @@ def testDebugIsCallable(self): self.assertEqual(output, ["select 1", "select 2"]) self.assertEqual(self.get_output(), "") - def testDebugMultipleArgs(self): + def test_debug_multiple_args(self): output = [] self.db.debug = output.append args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]] @@ -4905,7 +4907,7 @@ def testDebugMultipleArgs(self): class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" - def getLeaks(self, fut): + def get_leaks(self, fut): ids = set() objs = [] add_ids = ids.update @@ -4918,20 +4920,20 @@ def getLeaks(self, fut): objs[:] = [obj for obj in objs if id(obj) not in ids] self.assertEqual(len(objs), 0) - def testLeaksWithClose(self): + def test_leaks_with_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() db.close() - self.getLeaks(fut) + self.get_leaks(fut) - def testLeaksWithoutClose(self): + def test_leaks_without_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() - self.getLeaks(fut) + self.get_leaks(fut) if __name__ == '__main__': diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 914450f5..5a49e9d2 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -20,59 +20,59 @@ class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" - def testhasPgError(self): + def testhas_pg_error(self): self.assertTrue(issubclass(pg.Error, Exception)) - def testhasPgWarning(self): + def testhas_pg_warning(self): self.assertTrue(issubclass(pg.Warning, Exception)) - def testhasPgInterfaceError(self): + def testhas_pg_interface_error(self): self.assertTrue(issubclass(pg.InterfaceError, pg.Error)) - def testhasPgDatabaseError(self): + def testhas_pg_database_error(self): self.assertTrue(issubclass(pg.DatabaseError, pg.Error)) - def testhasPgInternalError(self): + def testhas_pg_internal_error(self): self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError)) - def testhasPgOperationalError(self): + def testhas_pg_operational_error(self): self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError)) - def testhasPgProgrammingError(self): + def testhas_pg_programming_error(self): self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError)) - def testhasPgIntegrityError(self): + def testhas_pg_integrity_error(self): self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError)) - def testhasPgDataError(self): + def testhas_pg_data_error(self): self.assertTrue(issubclass(pg.DataError, pg.DatabaseError)) - def testhasPgNotSupportedError(self): + def 
testhas_pg_not_supported_error(self): self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError)) - def testhasPgInvalidResultError(self): + def testhas_pg_invalid_result_error(self): self.assertTrue(issubclass(pg.InvalidResultError, pg.DataError)) - def testhasPgNoResultError(self): + def testhas_pg_no_result_error(self): self.assertTrue(issubclass(pg.NoResultError, pg.InvalidResultError)) - def testhasPgMultipleResultsError(self): + def testhas_pg_multiple_results_error(self): self.assertTrue( issubclass(pg.MultipleResultsError, pg.InvalidResultError)) - def testhasConnect(self): + def testhas_connect(self): self.assertTrue(callable(pg.connect)) - def testhasEscapeString(self): + def testhas_escape_string(self): self.assertTrue(callable(pg.escape_string)) - def testhasEscapeBytea(self): + def testhas_escape_bytea(self): self.assertTrue(callable(pg.escape_bytea)) - def testhasUnescapeBytea(self): + def testhas_unescape_bytea(self): self.assertTrue(callable(pg.unescape_bytea)) - def testDefHost(self): + def test_def_host(self): d0 = pg.get_defhost() d1 = 'pgtesthost' pg.set_defhost(d1) @@ -80,7 +80,7 @@ def testDefHost(self): pg.set_defhost(d0) self.assertEqual(pg.get_defhost(), d0) - def testDefPort(self): + def test_def_port(self): d0 = pg.get_defport() d1 = 1234 pg.set_defport(d1) @@ -92,7 +92,7 @@ def testDefPort(self): d0 = None self.assertEqual(pg.get_defport(), d0) - def testDefOpt(self): + def test_def_opt(self): d0 = pg.get_defopt() d1 = '-h pgtesthost -p 1234' pg.set_defopt(d1) @@ -100,7 +100,7 @@ def testDefOpt(self): pg.set_defopt(d0) self.assertEqual(pg.get_defopt(), d0) - def testDefBase(self): + def test_def_base(self): d0 = pg.get_defbase() d1 = 'pgtestdb' pg.set_defbase(d1) @@ -108,7 +108,7 @@ def testDefBase(self): pg.set_defbase(d0) self.assertEqual(pg.get_defbase(), d0) - def testPqlibVersion(self): + def test_pqlib_version(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, int) @@ -216,7 +216,7 @@ class TestParseArray(unittest.TestCase): ('[3:5]={{1,2,3},{4,5,6}}', int, ValueError), ('[1:1][-2:-1][3:5]={{1,2,3},{4,5,6}}', int, ValueError)] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_array self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -235,13 +235,13 @@ def testParserParams(self): self.assertEqual(f('{}', str), []) self.assertEqual(f('{}', str, b';'), []) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_array('{a,b,c}') self.assertIsInstance(r, list) self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_array r = f('{{a,b,c}}') self.assertIsInstance(r, list) @@ -273,7 +273,7 @@ def testParserNested(self): r = r[0] self.assertEqual(r, 'abc') - def testParserTooDeeplyNested(self): + def test_parser_too_deeply_nested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: r = '{' * n + 'a,b,c' + '}' * n @@ -288,7 +288,7 @@ def testParserTooDeeplyNested(self): self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserCast(self): + def test_parser_cast(self): f = pg.cast_array self.assertEqual(f('{1}'), ['1']) self.assertEqual(f('{1}', None), ['1']) @@ -303,7 +303,7 @@ def cast(s): return f'{s} is ok' self.assertEqual(f('{a}', cast), ['a is ok']) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_array self.assertEqual(f('{1,2}'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b','), ['1', '2']) @@ -311,7 
+311,7 @@ def testParserDelim(self): self.assertEqual(f('{1;2}', delim=b';'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b';'), ['1,2']) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_array for string, cast, expected in self.test_strings: if expected is ValueError: @@ -319,7 +319,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_array for string, cast, expected in self.test_strings: @@ -330,7 +330,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_array def replace_comma(value): @@ -491,7 +491,7 @@ class TestParseRecord(unittest.TestCase): ('(fuzzy dice,"42","1.9375")', (str, int, float), ('fuzzy dice', 42, 1.9375))] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_record self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -510,20 +510,20 @@ def testParserParams(self): self.assertEqual(f('()', str), (None,)) self.assertEqual(f('()', str, b';'), (None,)) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_record('(a,b,c)') self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertEqual(r, ('a', 'b', 'c')) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_record self.assertRaises(ValueError, f, '((a,b,c))') self.assertRaises(ValueError, f, '((a,b),(c,d))') self.assertRaises(ValueError, f, '((a),(b),(c))') self.assertRaises(ValueError, f, '(((((((abc)))))))') - def testParserManyElements(self): + def test_parser_many_elements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: r = ','.join(map(str, range(n))) @@ -531,7 +531,7 @@ def testParserManyElements(self): r = f(r, int) self.assertEqual(r, tuple(range(n))) - def testParserCastUniform(self): + def test_parser_cast_uniform(self): f = pg.cast_record self.assertEqual(f('(1)'), ('1',)) self.assertEqual(f('(1)', None), ('1',)) @@ -546,7 +546,7 @@ def cast(s): return f'{s} is ok' self.assertEqual(f('(a)', cast), ('a is ok',)) - def testParserCastNonUniform(self): + def test_parser_cast_non_uniform(self): f = pg.cast_record self.assertEqual(f('(1)', []), ('1',)) self.assertEqual(f('(1)', [None]), ('1',)) @@ -583,7 +583,7 @@ def cast2(s): f('(1,2,3,4,5,6)', [int, float, str, None, cast1, cast2]), (1, 2.0, '3', '4', '5 is ok', 'and 6 is ok, too')) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_record self.assertEqual(f('(1,2)'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b','), ('1', '2')) @@ -591,7 +591,7 @@ def testParserDelim(self): self.assertEqual(f('(1;2)', delim=b';'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b';'), ('1,2',)) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_record for string, cast, expected in self.test_strings: if expected is ValueError: @@ -599,7 +599,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_record for string, cast, expected in self.test_strings: @@ -610,7 +610,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_record def replace_comma(value): @@ -665,7 +665,7 @@ class 
TestParseHStore(unittest.TestCase): (r'k\=\>v=>"k=>v"', {'k=>v': 'k=>v'}), ('a\\,b=>a,b=>a', {'a,b': 'a', 'b': 'a'})] - def testParser(self): + def test_parser(self): f = pg.cast_hstore self.assertRaises(TypeError, f) @@ -842,7 +842,7 @@ class TestCastInterval(unittest.TestCase): '@ 10 mons 3 days -3 hours -55 mins -5.999993 secs ago', 'P-10M-3DT3H55M5.999993S'))] - def testCastInterval(self): + def test_cast_interval(self): for result, values in self.intervals: f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result @@ -864,7 +864,7 @@ class TestEscapeFunctions(unittest.TestCase): """ - def testEscapeString(self): + def test_escape_string(self): f = pg.escape_string r = f(b'plain') self.assertIsInstance(r, bytes) @@ -876,7 +876,7 @@ def testEscapeString(self): self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") - def testEscapeBytea(self): + def test_escape_bytea(self): f = pg.escape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -888,7 +888,7 @@ def testEscapeBytea(self): self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") - def testUnescapeBytea(self): + def test_unescape_bytea(self): f = pg.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -916,10 +916,10 @@ class TestConfigFunctions(unittest.TestCase): """ - def testGetDatestyle(self): + def test_get_datestyle(self): self.assertIsNone(pg.get_datestyle()) - def testSetDatestyle(self): + def test_set_datestyle(self): datestyle = pg.get_datestyle() try: pg.set_datestyle('ISO, YMD') @@ -939,12 +939,12 @@ def testSetDatestyle(self): finally: pg.set_datestyle(datestyle) - def testGetDecimalPoint(self): + def test_get_decimal_point(self): r = pg.get_decimal_point() self.assertIsInstance(r, str) self.assertEqual(r, '.') - def testSetDecimalPoint(self): + def test_set_decimal_point(self): point = pg.get_decimal_point() try: pg.set_decimal_point('*') @@ -957,11 +957,11 @@ def testSetDecimalPoint(self): self.assertIsInstance(r, str) self.assertEqual(r, point) - def testGetDecimal(self): + def test_get_decimal(self): r = pg.get_decimal() self.assertIs(r, pg.Decimal) - def testSetDecimal(self): + def test_set_decimal(self): decimal_class = pg.Decimal try: pg.set_decimal(int) @@ -972,12 +972,12 @@ def testSetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testGetBool(self): + def test_get_bool(self): r = pg.get_bool() self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() try: pg.set_bool(False) @@ -995,12 +995,12 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, use_bool) - def testGetByteaEscaped(self): + def test_get_bytea_escaped(self): r = pg.get_bytea_escaped() self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() try: pg.set_bytea_escaped(True) @@ -1018,12 +1018,12 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, bytea_escaped) - def testGetJsondecode(self): + def test_get_jsondecode(self): r = pg.get_jsondecode() self.assertTrue(callable(r)) self.assertIs(r, json.loads) - def testSetJsondecode(self): + def test_set_jsondecode(self): jsondecode = pg.get_jsondecode() try: pg.set_jsondecode(None) @@ -1042,7 +1042,7 @@ def testSetJsondecode(self): class TestModuleConstants(unittest.TestCase): """Test the existence of the documented module constants.""" - def testVersion(self): + def test_version(self): v = 
pg.version
         self.assertIsInstance(v, str)
         # make sure the version conforms to PEP440
diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py
index 039ca51f..afe48a21 100755
--- a/tests/test_classic_largeobj.py
+++ b/tests/test_classic_largeobj.py
@@ -32,7 +32,7 @@ def connect():
 class TestModuleConstants(unittest.TestCase):
     """Test the existence of the documented module constants."""
 
-    def testLargeObjectIntConstants(self):
+    def test_large_object_int_constants(self):
         names = 'INV_READ INV_WRITE SEEK_SET SEEK_CUR SEEK_END'.split()
         for name in names:
             try:
@@ -53,7 +53,7 @@ def tearDown(self):
         self.c.query('rollback')
         self.c.close()
 
-    def assertIsLargeObject(self, obj):
+    def assertIsLargeObject(self, obj):  # noqa: N802
         self.assertIsNotNone(obj)
         self.assertTrue(hasattr(obj, 'open'))
         self.assertTrue(hasattr(obj, 'close'))
@@ -66,14 +66,14 @@ def assertIsLargeObject(self, obj):
         self.assertIsInstance(obj.error, str)
         self.assertFalse(obj.error)
 
-    def testLoCreate(self):
+    def test_lo_create(self):
         large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE)
         try:
             self.assertIsLargeObject(large_object)
         finally:
             del large_object
 
-    def testGetLo(self):
+    def test_get_lo(self):
         large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE)
         try:
             self.assertIsLargeObject(large_object)
@@ -103,7 +103,7 @@ def testGetLo(self):
             self.assertIsInstance(r, bytes)
             self.assertEqual(r, data)
 
-    def testLoImport(self):
+    def test_lo_import(self):
         if windows:
             # NamedTemporaryFiles don't work well here
             fname = 'temp_test_pg_largeobj_import.txt'
@@ -164,24 +164,24 @@ def tearDown(self):
             pass
         self.pgcnx.close()
 
-    def testClassName(self):
+    def test_class_name(self):
         self.assertEqual(self.obj.__class__.__name__, 'LargeObject')
 
-    def testModuleName(self):
+    def test_module_name(self):
         self.assertEqual(self.obj.__class__.__module__, 'pg')
 
-    def testOid(self):
+    def test_oid(self):
         self.assertIsInstance(self.obj.oid, int)
         self.assertNotEqual(self.obj.oid, 0)
 
-    def testPgcn(self):
+    def test_pgcn(self):
         self.assertIs(self.obj.pgcnx, self.pgcnx)
 
-    def testError(self):
+    def test_error(self):
         self.assertIsInstance(self.obj.error, str)
         self.assertEqual(self.obj.error, '')
 
-    def testStr(self):
+    def test_str(self):
         self.obj.open(pg.INV_WRITE)
         data = b'some object to be printed'
         self.obj.write(data)
@@ -192,11 +192,11 @@ def testStr(self):
         r = str(self.obj)
         self.assertEqual(r, f'Closed large object, oid {oid}')
 
-    def testRepr(self):
+    def test_repr(self):
         r = repr(self.obj)
         self.assertTrue(r.startswith('<pg.LargeObject'))

From: Christoph Zwerschke
Date: Sat, 2 Sep 2023 01:10:05 +0200
Subject: [PATCH 142/194] Minor improvements using more ruff-specific linting

---
 docs/contents/changelog.rst      |  1 +
 docs/contents/pgdb/types.rst     |  6 +--
 pg.py                            | 39 +++++++++---------
 pgdb.py                          | 69 ++++++++++++++++++----------
 pyproject.toml                   |  1 +
 setup.py                         |  6 +--
 tests/dbapi20.py                 |  8 ++--
 tests/test_classic_connection.py |  9 +++--
 tests/test_classic_functions.py  |  9 +++--
 tests/test_dbapi20.py            |  5 ++-
 tests/test_dbapi20_copy.py       | 14 ++++---
 11 files changed, 88 insertions(+), 79 deletions(-)

diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst
index 67408993..d240daa2 100644
--- a/docs/contents/changelog.rst
+++ b/docs/contents/changelog.rst
@@ -7,6 +7,7 @@ Version 6.0 (to be released)
   and PostgreSQL older than version 10 (released October 2017).
 - Removed deprecated function `pg.pgnotify()`.
 - Removed the deprecated method `ntuples()` of the `pg.Query` object.
+- Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`.
- Modernized code and tools for development, testing, linting and building. Version 5.2.5 (2023-08-28) diff --git a/docs/contents/pgdb/types.rst b/docs/contents/pgdb/types.rst index f28e23f7..d739df32 100644 --- a/docs/contents/pgdb/types.rst +++ b/docs/contents/pgdb/types.rst @@ -101,15 +101,15 @@ Example for using a type constructor:: Type objects ------------ -.. class:: Type +.. class:: DbType The :attr:`Cursor.description` attribute returns information about each of the result columns of a query. The *type_code* must compare equal to one -of the :class:`Type` objects defined below. Type objects can be equal to +of the :class:`DbType` objects defined below. Type objects can be equal to more than one type code (e.g. :class:`DATETIME` is equal to the type codes for ``date``, ``time`` and ``timestamp`` columns). -The pgdb module exports the following :class:`Type` objects as part of the +The pgdb module exports the following :class:`DbType` objects as part of the DB-API 2 standard: .. object:: STRING diff --git a/pg.py b/pg.py index f8dfb1be..25dc16e7 100644 --- a/pg.py +++ b/pg.py @@ -20,6 +20,22 @@ For a DB-API 2 compliant interface use the newer pgdb module. """ +import select +import weakref +from collections import OrderedDict, namedtuple +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan +from operator import itemgetter +from re import compile as regex +from types import MappingProxyType +from typing import ClassVar, Dict, List, Mapping, Type, Union +from uuid import UUID + try: from _pg import version except ImportError as e: # noqa: F841 @@ -149,21 +165,6 @@ 'set_jsondecode', 'set_query_helpers', 'set_typecast', 'version', '__version__'] -import select -import weakref -from collections import OrderedDict, namedtuple -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from operator import itemgetter -from re import compile as regex -from typing import Dict, List, Union # noqa: F401 -from uuid import UUID - # Auxiliary classes and functions that are independent of a DB connection: def get_args(func): @@ -239,7 +240,7 @@ class _SimpleTypes(dict): The corresponding Python types and simple names are also mapped. 
""" - _type_aliases = { + _type_aliases: Mapping[str, List[Union[str, type]]] = MappingProxyType({ 'bool': [bool], 'bytea': [Bytea], 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', @@ -251,13 +252,13 @@ class _SimpleTypes(dict): 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], 'num': ['numeric', Decimal], 'money': [], 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] - } # type: Dict[str, List[Union[str, type]]] + }) # noinspection PyMissingConstructor def __init__(self): """Initialize type mapping.""" for typ, keys in self._type_aliases.items(): - keys = [typ] + keys + keys = [typ, *keys] for key in keys: self[key] = typ if isinstance(key, str): @@ -969,7 +970,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = { + defaults: ClassVar[Dict[str, Type]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, diff --git a/pgdb.py b/pgdb.py index 5752ac4d..00e57f02 100644 --- a/pgdb.py +++ b/pgdb.py @@ -64,6 +64,20 @@ connection.close() # close the connection """ +from collections import namedtuple +from collections.abc import Iterable +from datetime import date, datetime, time, timedelta +from decimal import Decimal as StdDecimal +from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan +from re import compile as regex +from time import localtime +from typing import ClassVar, Dict, Type +from uuid import UUID as Uuid # noqa: N811 + try: from _pg import version except ImportError as e: # noqa: F841 @@ -137,19 +151,6 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from collections import namedtuple -from collections.abc import Iterable -from datetime import date, datetime, time, timedelta -from decimal import Decimal as StdDecimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from re import compile as regex -from time import localtime -from uuid import UUID as Uuid # noqa: N811 - Decimal = StdDecimal @@ -417,7 +418,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = { + defaults: ClassVar[Dict[str, Type]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -1579,7 +1580,7 @@ def connect(dsn=None, # *** Types Handling *** -class Type(frozenset): +class DbType(frozenset): """Type class for a couple of PostgreSQL data types. PostgreSQL is object-oriented: types are dynamic. 
@@ -1651,30 +1652,30 @@ def __ne__(self, other): # Mandatory type objects defined by DB-API 2 specs: -STRING = Type('char bpchar name text varchar') -BINARY = Type('bytea') -NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money') -DATETIME = Type('date time timetz timestamp timestamptz interval' +STRING = DbType('char bpchar name text varchar') +BINARY = DbType('bytea') +NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') +DATETIME = DbType('date time timetz timestamp timestamptz interval' ' abstime reltime') # these are very old -ROWID = Type('oid') +ROWID = DbType('oid') # Additional type objects (more specific): -BOOL = Type('bool') -SMALLINT = Type('int2') -INTEGER = Type('int2 int4 int8 serial') -LONG = Type('int8') -FLOAT = Type('float4 float8') -NUMERIC = Type('numeric') -MONEY = Type('money') -DATE = Type('date') -TIME = Type('time timetz') -TIMESTAMP = Type('timestamp timestamptz') -INTERVAL = Type('interval') -UUID = Type('uuid') -HSTORE = Type('hstore') -JSON = Type('json jsonb') +BOOL = DbType('bool') +SMALLINT = DbType('int2') +INTEGER = DbType('int2 int4 int8 serial') +LONG = DbType('int8') +FLOAT = DbType('float4 float8') +NUMERIC = DbType('numeric') +MONEY = DbType('money') +DATE = DbType('date') +TIME = DbType('time timetz') +TIMESTAMP = DbType('timestamp timestamptz') +INTERVAL = DbType('interval') +UUID = DbType('uuid') +HSTORE = DbType('hstore') +JSON = DbType('json jsonb') # Type object for arrays (also equate to their base types): diff --git a/pyproject.toml b/pyproject.toml index 9603b825..382b09ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ select = [ "N", # pep8-naming "UP", # pyupgrade "D", # pydocstyle + "RUF", # ruff ] exclude = [ "__pycache__", diff --git a/setup.py b/setup.py index 09c6e2f8..a52f315d 100755 --- a/setup.py +++ b/setup.py @@ -67,15 +67,15 @@ class build_pg_ext(build_ext): # noqa: N801 description = "build the PyGreSQL C extension" - user_options = build_ext.user_options + [ + user_options = [*build_ext.user_options, # noqa: RUF012 ('strict', None, "count all compiler warnings as errors"), ('memory-size', None, "enable memory size function"), ('no-memory-size', None, "disable memory size function")] - boolean_options = build_ext.boolean_options + [ + boolean_options = [*build_ext.boolean_options, # noqa: RUF012 'strict', 'memory-size'] - negative_opt = { + negative_opt = { # noqa: RUF012 'no-memory-size': 'memory-size'} def get_compiler(self): diff --git a/tests/dbapi20.py b/tests/dbapi20.py index bb913475..12a7647b 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -11,7 +11,7 @@ import time import unittest -from typing import Any, Dict, Tuple +from typing import Any, Mapping, Tuple class DatabaseAPI20Test(unittest.TestCase): @@ -41,7 +41,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # method is to be found driver: Any = None connect_args: Tuple = () # List of arguments to pass to connect - connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = f'create table {table_prefix}booze (name varchar(20))' @@ -502,14 +502,14 @@ def test_next(self): finally: con.close() - samples = [ + samples = ( 'Carlton Cold', 'Carlton Draft', 'Mountain Goat', 'Redback', 'Victoria Bitter', 'XXXX' - ] + ) def _populate(self): """Return a list of SQL commands to setup the DB for fetching tests.""" diff --git a/tests/test_classic_connection.py 
b/tests/test_classic_connection.py
index ed31bed8..440142c7 100755
--- a/tests/test_classic_connection.py
+++ b/tests/test_classic_connection.py
@@ -16,6 +16,7 @@
 from collections import namedtuple
 from collections.abc import Iterable
 from decimal import Decimal
+from typing import Sequence, Tuple
 
 import pg  # the module under test
 
@@ -1743,7 +1744,7 @@ def tearDown(self):
         self.c.query("truncate table test")
         self.c.close()
 
-    data = [
+    data: Sequence[Tuple] = [
         (-1, -1, -1, True, '1492-10-12', '08:30:00',
          -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
         (0, 0, 0, False, '1607-04-14', '09:00:00',
@@ -1825,7 +1826,7 @@ def test_inserttable_from_tuple_of_lists(self):
         self.assertEqual(self.get_back(), self.data)
 
     def test_inserttable_with_different_row_sizes(self):
-        data = self.data[:-1] + [self.data[-1][:-1]]
+        data = [*self.data[:-1], self.data[-1][:-1]]
         try:
             self.c.inserttable('test', data)
         except TypeError as e:
@@ -2107,10 +2108,10 @@ def test_insert_table_big_row_size(self):
 
     def test_insert_table_small_int_overflow(self):
         rest_row = self.data[2][1:]
-        data = [(32000,) + rest_row]
+        data = [(32000, *rest_row)]
         self.c.inserttable('test', data)
         self.assertEqual(self.get_back(), data)
-        data = [(33000,) + rest_row]
+        data = [(33000, *rest_row)]
         try:
             self.c.inserttable('test', data)
         except ValueError as e:
diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py
index 5a49e9d2..5babc816 100755
--- a/tests/test_classic_functions.py
+++ b/tests/test_classic_functions.py
@@ -13,6 +13,7 @@
 import re
 import unittest
 from datetime import timedelta
+from typing import Any, Sequence, Tuple, Type
 
 import pg  # the module under test
 
@@ -119,7 +120,7 @@ def test_pqlib_version(self):
 class TestParseArray(unittest.TestCase):
     """Test the array parser."""
 
-    test_strings = [
+    test_strings: Sequence[Tuple[str, Type, Any]] = [
        ('', str, ValueError),
        ('{}', None, []),
        ('{}', str, []),
@@ -353,7 +354,7 @@ def replace_comma(value):
 class TestParseRecord(unittest.TestCase):
     """Test the record parser."""
 
-    test_strings = [
+    test_strings: Sequence[Tuple[str, Type, Any]] = [
        ('', None, ValueError),
        ('', str, ValueError),
        ('(', None, ValueError),
@@ -634,7 +635,7 @@ def replace_comma(value):
 class TestParseHStore(unittest.TestCase):
     """Test the hstore parser."""
 
-    test_strings = [
+    test_strings: Sequence[Tuple[str, Any]] = [
        ('', {}),
        ('=>', ValueError),
        ('""=>', ValueError),
@@ -683,7 +684,7 @@ def test_parser(self):
 class TestCastInterval(unittest.TestCase):
     """Test the interval typecast function."""
 
-    intervals = [
+    intervals: Sequence[Tuple[Tuple[int, ...], Tuple[str, ...]]] = [
        ((0, 0, 0, 1, 0, 0, 0),
         ('1:00:00', '01:00:00', '@ 1 hour', 'PT1H')),
        ((0, 0, 0, -1, 0, 0, 0),
diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py
index 8ea52a7b..9fd00165 100755
--- a/tests/test_dbapi20.py
+++ b/tests/test_dbapi20.py
@@ -3,6 +3,7 @@
 import gc
 import unittest
 from datetime import date, datetime, time, timedelta, timezone
+from typing import Any, Mapping
 from uuid import UUID as Uuid  # noqa: N811
 
 import pgdb
@@ -25,7 +26,7 @@ class TestPgDb(dbapi20.DatabaseAPI20Test):
 
     driver = pgdb
     connect_args = ()
-    connect_kw_args = {
+    connect_kw_args: Mapping[str, Any] = {
        'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}",
        'user': dbuser, 'password': dbpasswd}
@@ -1323,7 +1324,7 @@ def test_no_close(self):
         data = ('hello', 'world')
         con = self._connect()
         cur = con.cursor()
-        cur.build_row_factory = lambda: tuple  # noqa: E731
+        cur.build_row_factory = lambda: tuple
         cur.execute("select
%s, %s", data) row = cur.fetchone() self.assertEqual(row, data) diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index c4e8dd74..170c33c1 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,6 +11,7 @@ import unittest from collections.abc import Iterable +from typing import Sequence, Tuple import pgdb # the module under test @@ -154,9 +155,10 @@ def tearDown(self): except Exception: pass - data = [(1935, 'Luciano Pavarotti'), - (1941, 'Plácido Domingo'), - (1946, 'José Carreras')] + data: Sequence[Tuple[int, str]] = [ + (1935, 'Luciano Pavarotti'), + (1941, 'Plácido Domingo'), + (1946, 'José Carreras')] can_encode = True @@ -447,11 +449,11 @@ def test_null(self): self.cursor.execute('insert into copytest values(4, null)') try: ret = list(self.copy_to()) - self.assertEqual(ret, data + ['4\t\\N\n']) + self.assertEqual(ret, [*data, '4\t\\N\n']) ret = list(self.copy_to(null='Nix')) - self.assertEqual(ret, data + ['4\tNix\n']) + self.assertEqual(ret, [*data, '4\tNix\n']) ret = list(self.copy_to(null='')) - self.assertEqual(ret, data + ['4\t\n']) + self.assertEqual(ret, [*data, '4\t\n']) finally: self.cursor.execute('delete from copytest where id=4') From 33859e51483e3b0fece3c63213ec4b6e1c6bca11 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 09:42:38 +0200 Subject: [PATCH 143/194] Add ruff for local testing when provisioning --- .devcontainer/provision.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index a42337b8..2f3651d6 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -35,7 +35,9 @@ python3.9 -m pip install build python3.10 -m pip install build python3.11 -m pip install build -sudo apt-get install -y tox python3-poetry +pip install ruff + +sudo apt-get install -y tox # install PostgreSQL client tools From 950d7d8e22034e711aa79d2f49f5a41904ed74e8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 10:21:42 +0200 Subject: [PATCH 144/194] Add testing with flake8-bugbear --- pg.py | 11 ++++----- pgdb.py | 36 ++++++++++++++++-------------- pyproject.toml | 2 ++ setup.py | 6 +++-- tests/test_classic.py | 6 ++--- tests/test_classic_connection.py | 4 ++-- tests/test_classic_dbwrapper.py | 24 ++++++++++---------- tests/test_classic_functions.py | 4 ++-- tests/test_classic_notification.py | 24 ++++++++++---------- tests/test_dbapi20.py | 2 +- tests/test_dbapi20_copy.py | 4 ++-- 11 files changed, 65 insertions(+), 58 deletions(-) diff --git a/pg.py b/pg.py index 25dc16e7..2df124b0 100644 --- a/pg.py +++ b/pg.py @@ -494,7 +494,7 @@ def _adapt_record(self, v, typ): raise TypeError(f'Record parameter {v} has wrong size') adapt = self.adapt value = [] - for v, t in zip(v, typ): + for v, t in zip(v, typ): # noqa: B020 v = adapt(v, t) if v is None: v = '' @@ -1989,7 +1989,7 @@ def pkey(self, table, composite=False, flush=False): self._do_debug('The pkey cache has been flushed') try: # cache lookup pkey = pkeys[table] - except KeyError: # cache miss, check the database + except KeyError as e: # cache miss, check the database q = ("SELECT a.attname, a.attnum, i.indkey" " FROM pg_catalog.pg_index i" " JOIN pg_catalog.pg_attribute a" @@ -2002,7 +2002,7 @@ def pkey(self, table, composite=False, flush=False): _quote_if_unqualified('$1', table)) pkey = self.db.query(q, (table,)).getresult() if not pkey: - raise KeyError(f'Table {table} has no primary key') + raise KeyError(f'Table {table} has no primary key') from e # 
we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: @@ -2173,12 +2173,13 @@ def get(self, table, row, keyname=None): if not keyname: try: # if keyname is not specified, try using the primary key keyname = self.pkey(table, True) - except KeyError: # the table has no primary key + except KeyError as e: # the table has no primary key # try using the oid instead if qoid and isinstance(row, dict) and 'oid' in row: keyname = ('oid',) else: - raise _prg_error(f'Table {table} has no primary key') + raise _prg_error( + f'Table {table} has no primary key') from e else: # the table has a primary key # check whether all key columns have values if isinstance(row, dict) and not set(keyname).issubset(row): diff --git a/pgdb.py b/pgdb.py index 00e57f02..74db29e9 100644 --- a/pgdb.py +++ b/pgdb.py @@ -993,7 +993,8 @@ def executemany(self, operation, seq_of_parameters): raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) + raise _db_error( + f"Error in '{sql}': '{err}'", InterfaceError) from err except Exception as err: raise _op_error(f"Internal error in '{sql}': {err}") from err # then initialize result raw count and description @@ -1090,9 +1091,10 @@ def copy_from(self, stream, table, binary_format = format == 'binary' try: read = stream.read - except AttributeError: + except AttributeError as e: if size: - raise ValueError("Size must only be set for file-like objects") + raise ValueError( + "Size must only be set for file-like objects") from e if binary_format: input_type = bytes type_name = 'byte strings' @@ -1102,7 +1104,7 @@ def copy_from(self, stream, table, if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): - raise ValueError(f"The input must be {type_name}") + raise ValueError(f"The input must be {type_name}") from e if not binary_format: if isinstance(stream, str): if not stream.endswith('\n'): @@ -1130,7 +1132,7 @@ def chunks(): yield chunk else: - raise TypeError("Need an input stream to copy from") + raise TypeError("Need an input stream to copy from") from e else: if size is None: size = 8192 @@ -1233,8 +1235,8 @@ def copy_to(self, stream, table, if stream is not None: try: write = stream.write - except AttributeError: - raise TypeError("Need an output stream to copy to") + except AttributeError as e: + raise TypeError("Need an output stream to copy to") from e if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): @@ -1405,8 +1407,8 @@ def __init__(self, cnx): self.autocommit = False try: self._cnx.source() - except Exception: - raise _op_error("Invalid connection") + except Exception as e: + raise _op_error("Invalid connection") from e def __enter__(self): """Enter the runtime context for the connection object. 
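The re-raises above all chain with "from", keeping the original error
reachable as __cause__ of the new one. A minimal self-contained sketch of
the pattern (the function and names here are hypothetical, not module code):

    def lookup(mapping, key):
        try:
            return mapping[key]
        except KeyError as e:
            # "from e" preserves the KeyError and its traceback as __cause__
            raise LookupError(f"no entry for {key!r}") from e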
@@ -1420,8 +1422,8 @@ def __enter__(self): self._cnx.source().execute("BEGIN") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") + except Exception as e: + raise _op_error("Can't start transaction") from e else: self._tnx = True return self @@ -1466,8 +1468,8 @@ def commit(self): self._cnx.source().execute("COMMIT") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't commit transaction") + except Exception as e: + raise _op_error("Can't commit transaction") from e else: raise _op_error("Connection has been closed") @@ -1480,8 +1482,8 @@ def rollback(self): self._cnx.source().execute("ROLLBACK") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't rollback transaction") + except Exception as e: + raise _op_error("Can't rollback transaction") from e else: raise _op_error("Connection has been closed") @@ -1490,8 +1492,8 @@ def cursor(self): if self._cnx: try: return self.cursor_type(self) - except Exception: - raise _op_error("Invalid connection") + except Exception as e: + raise _op_error("Invalid connection") from e else: raise _op_error("Connection has been closed") diff --git a/pyproject.toml b/pyproject.toml index 382b09ca..2abfeb63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ download = "https://pygresql.github.io/download/" "mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" [tool.ruff] +target-version = "py37" line-length = 79 select = [ "E", # pycodestyle @@ -50,6 +51,7 @@ select = [ "UP", # pyupgrade "D", # pydocstyle "RUF", # ruff + "B", # bugbear ] exclude = [ "__pycache__", diff --git a/setup.py b/setup.py index a52f315d..6e4c5fd4 100755 --- a/setup.py +++ b/setup.py @@ -90,7 +90,8 @@ def initialize_options(self): supported = pg_version >= (10, 0) if not supported: warnings.warn( - "PyGreSQL does not support the installed PostgreSQL version.") + "PyGreSQL does not support the installed PostgreSQL version.", + stacklevel=2) def finalize_options(self): """Set final values for all build_pg options.""" @@ -104,7 +105,8 @@ def finalize_options(self): if not supported: warnings.warn( "The installed PostgreSQL version" - " does not support the memory size function.") + " does not support the memory size function.", + stacklevel=2) if sys.platform == 'win32': libraries[0] = 'lib' + libraries[0] if os.path.exists(os.path.join( diff --git a/tests/test_classic.py b/tests/test_classic.py index d6763074..18af07b2 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -215,7 +215,7 @@ def test_notify(self, options=None): thread.start() try: # Wait until the thread has started. - for n in range(500): + for _n in range(500): if target.listening: break sleep(0.01) @@ -237,7 +237,7 @@ def test_notify(self, options=None): if two_payloads: db2.commit() # Wait until the notification has been caught. - for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) @@ -256,7 +256,7 @@ def test_notify(self, options=None): db2.query("notify stop_event_1, 'payload 2'") db2.close() # Wait until the notification has been caught. 
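# (The renames below follow the same bugbear convention: an underscore
# prefix marks a loop counter that is never read. A minimal sketch of the
# idiom used in these tests:
#     for _n in range(500):  # poll up to 500 times, about 5 seconds
#         if arg_dict['called'] or self.notify_timeout:
#             break
#         sleep(0.01)
# Only the bounded retrying matters here, not the counter itself.)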
- for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 440142c7..1fa9edb6 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1357,7 +1357,7 @@ def test_iterate(self): def test_iterate_twice(self): r = self.c.query("select generate_series(3,5)") - for i in range(2): + for _i in range(2): self.assertEqual(list(r), [(3,), (4,), (5,)]) def test_iterate_two_columns(self): @@ -2652,7 +2652,7 @@ def test_set_row_factory_size(self): query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): pg.set_row_factory_size(maxsize) - for i in range(3): + for _i in range(3): for q in queries: r = query(q).namedresult()[0] if q.endswith('abc'): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 1f7b3aac..c563d932 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -2735,7 +2735,7 @@ def test_truncate(self): truncate('test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2744,7 +2744,7 @@ def test_truncate(self): self.assertEqual(r, 0) self.create_table('test_table_2', 'n smallint', temporary=True) for t in (list, tuple, set): - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") query("insert into test_table_2 values (2)") q = ("select (select count(*) from test_table)," @@ -2760,7 +2760,7 @@ def test_truncate_restart(self): self.assertRaises(TypeError, truncate, 'test_table', restart='invalid') query = self.db.query self.create_table('test_table', 'n serial, t text') - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" r = query(q).getresult()[0] @@ -2768,14 +2768,14 @@ def test_truncate_restart(self): truncate('test_table') r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 4, 6)) truncate('test_table', restart=True) r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) @@ -2824,7 +2824,7 @@ def test_truncate_only(self): query = self.db.query self.create_table('test_parent', 'n smallint') self.create_table('test_child', 'm smallint) inherits (test_parent') - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," @@ -2834,7 +2834,7 @@ def test_truncate_only(self): truncate('test_parent') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2842,7 +2842,7 @@ def test_truncate_only(self): truncate('test_parent*') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2859,7 
+2859,7 @@ def test_truncate_only(self): self.create_table('test_child_2', 'm smallint) inherits (test_parent_2') for t in '', '_2': - for n in range(3): + for _n in range(3): query(f"insert into test_parent{t} (n) values (1)") query(f"insert into test_child{t} (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," @@ -2890,7 +2890,7 @@ def test_truncate_quoted(self): truncate(table) r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): + for _i in range(3): query(f'insert into "{table}" values (1)') r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -4703,11 +4703,11 @@ def setUpClass(cls): query(f"drop schema if exists {schema} cascade") try: query(f"create schema {schema}") - except pg.ProgrammingError: + except pg.ProgrammingError as e: raise RuntimeError( "The test user cannot create schemas.\n" f"Grant create on database {dbname} to the user" - " for running these tests.") + " for running these tests.") from e else: schema = "public" query(f"drop table if exists {schema}.t") diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 5babc816..37606b13 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -267,7 +267,7 @@ def test_parser_nested(self): self.assertEqual(len(r), 1) self.assertEqual(r[0], 'b') r = f('{{{{{{{abc}}}}}}}') - for i in range(7): + for _i in range(7): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) # noinspection PyUnresolvedReferences @@ -282,7 +282,7 @@ def test_parser_too_deeply_nested(self): self.assertRaises(ValueError, f, r) else: r = f(r) - for i in range(n - 1): + for _i in range(n - 1): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) r = r[0] diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 9e56bd6d..552f1ea5 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -208,7 +208,7 @@ def start_handler(self, event=None, arg_dict=None, thread.start() self.stopped = timeout == 0 self.addCleanup(self.stop_handler) - for n in range(500): + for _n in range(500): if handler.listening: break sleep(0.01) @@ -255,7 +255,7 @@ def notify_query(self, stop=False, payload=None): self.sent.append(arg_dict) def wait(self): - for n in range(500): + for _n in range(500): if self.timeout: return False if len(self.received) >= len(self.sent): @@ -309,15 +309,15 @@ def test_notify_with_args(self): def test_notify_several_times(self): arg_dict = {'test': 1} self.start_handler(arg_dict=arg_dict) - for count in range(3): + for _n in range(3): self.notify_query() self.receive() arg_dict['test'] += 1 - for count in range(2): + for _n in range(2): self.notify_handler() self.receive() arg_dict['test'] += 1 - for count in range(3): + for _n in range(3): self.notify_query() self.receive(stop=True) @@ -338,30 +338,30 @@ def test_notify_quoted_names(self): def test_notify_with_five_payloads(self): self.start_handler('gimme_5', {'test': 'Gimme 5'}) - for count in range(5): - self.notify_query(payload=f"Round {count}") + for n in range(5): + self.notify_query(payload=f"Round {n}") self.assertEqual(len(self.sent), 5) self.receive(stop=True) def test_receive_immediately(self): self.start_handler('immediate', {'test': 'immediate'}) - for count in range(3): - self.notify_query(payload=f"Round {count}") + for n in range(3): + self.notify_query(payload=f"Round {n}") self.receive() self.receive(stop=True) def test_notify_distinct_in_transaction(self): self.start_handler('test_transaction', 
{'transaction': True}) self.db.begin() - for count in range(3): - self.notify_query(payload=f'Round {count}') + for n in range(3): + self.notify_query(payload=f'Round {n}') self.db.commit() self.receive(stop=True) def test_notify_same_in_transaction(self): self.start_handler('test_transaction', {'transaction': True}) self.db.begin() - for count in range(3): + for _n in range(3): self.notify_query() self.db.commit() # these same notifications may be delivered as one, diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 9fd00165..657e820c 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1335,7 +1335,7 @@ def test_set_row_factory_size(self): cur = con.cursor() for maxsize in (None, 0, 1, 2, 3, 10, 1024): pgdb.set_row_factory_size(maxsize) - for i in range(3): + for _i in range(3): for q in queries: cur.execute(q) r = cur.fetchone() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 170c33c1..ca775001 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -181,7 +181,7 @@ def table_data(self): def check_table(self): self.assertEqual(self.table_data, self.data) - def check_rowcount(self, number=len(data)): + def check_rowcount(self, number=len(data)): # noqa: B008 self.assertEqual(self.cursor.rowcount, number) @@ -429,7 +429,7 @@ def test_generator_bytes(self): def test_rowcount_increment(self): ret = self.copy_to() self.assertIsInstance(ret, Iterable) - for n, row in enumerate(ret): + for n, _row in enumerate(ret): self.check_rowcount(n + 1) def test_decode(self): From d8033beee03c56a50bd6982a83129a41402d38bf Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 10:54:07 +0200 Subject: [PATCH 145/194] Add testing with flake8-bandit --- pg.py | 19 ++++++++++++------- pgdb.py | 4 ++-- pyproject.toml | 5 +++-- setup.py | 2 +- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pg.py b/pg.py index 2df124b0..ac86d480 100644 --- a/pg.py +++ b/pg.py @@ -1990,7 +1990,8 @@ def pkey(self, table, composite=False, flush=False): try: # cache lookup pkey = pkeys[table] except KeyError as e: # cache miss, check the database - q = ("SELECT a.attname, a.attnum, i.indkey" + q = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" " FROM pg_catalog.pg_index i" " JOIN pg_catalog.pg_attribute a" " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" @@ -2038,7 +2039,8 @@ def get_relations(self, kinds=None, system=False): where.append("s.nspname NOT SIMILAR" " TO 'pg/_%|information/_schema' ESCAPE '/'") where = " WHERE " + ' AND '.join(where) if where else '' - q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + q = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" " '.' 
OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)"
             " FROM pg_catalog.pg_class r"
             " JOIN pg_catalog.pg_namespace s"
@@ -2207,7 +2209,7 @@ def get(self, table, row, keyname=None):
             row[qoid] = row['oid']
             del row['oid']
         t = self._escape_qualified_name(table)
-        q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1'
+        q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1'  # noqa: S608
         self._do_debug(q, params)
         q = self.db.query(q, params)
         res = q.dictresult()
@@ -2259,7 +2261,8 @@ def insert(self, table, row=None, **kw):
         names, values = ', '.join(names), ', '.join(values)
         ret = 'oid, *' if qoid else '*'
         t = self._escape_qualified_name(table)
-        q = f'INSERT INTO {t} ({names}) VALUES ({values}) RETURNING {ret}'
+        q = (f'INSERT INTO {t} ({names})'  # noqa: S608
+             f' VALUES ({values}) RETURNING {ret}')
         self._do_debug(q, params)
         q = self.db.query(q, params)
         res = q.dictresult()
@@ -2322,7 +2325,8 @@ def update(self, table, row=None, **kw):
         values = ', '.join(values)
         ret = 'oid, *' if qoid else '*'
         t = self._escape_qualified_name(table)
-        q = f'UPDATE {t} SET {values} WHERE {where} RETURNING {ret}'
+        q = (f'UPDATE {t} SET {values}'  # noqa: S608
+             f' WHERE {where} RETURNING {ret}')
         self._do_debug(q, params)
         q = self.db.query(q, params)
         res = q.dictresult()
@@ -2417,7 +2421,8 @@ def upsert(self, table, row=None, **kw):
         do = 'update set ' + ', '.join(update) if update else 'nothing'
         ret = 'oid, *' if qoid else '*'
         t = self._escape_qualified_name(table)
-        q = (f'INSERT INTO {t} AS included ({names}) VALUES ({values})'
+        q = (f'INSERT INTO {t} AS included ({names})'  # noqa: S608
+             f' VALUES ({values})'
              f' ON CONFLICT ({target}) DO {do} RETURNING {ret}')
         self._do_debug(q, params)
         q = self.db.query(q, params)
@@ -2499,7 +2504,7 @@ def delete(self, table, row=None, **kw):
             row[qoid] = row['oid']
             del row['oid']
         t = self._escape_qualified_name(table)
-        q = f'DELETE FROM {t} WHERE {where}'
+        q = f'DELETE FROM {t} WHERE {where}'  # noqa: S608
         self._do_debug(q, params)
         res = self.db.query(q, params)
         return int(res)
diff --git a/pgdb.py b/pgdb.py
index 74db29e9..22f99498 100644
--- a/pgdb.py
+++ b/pgdb.py
@@ -686,7 +686,7 @@ def get_fields(self, typ):
         if not typ.relid:
             return None  # this type is not composite
         self._src.execute(
-            "SELECT attname, atttypid"
+            "SELECT attname, atttypid"  # noqa: S608
             " FROM pg_catalog.pg_attribute"
             f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}"
             " AND attnum OPERATOR(pg_catalog.>) 0"
@@ -1065,7 +1065,7 @@ def callproc(self, procname, parameters=None):
         """
         n = len(parameters) if parameters else 0
         s = ','.join(n * ['%s'])
-        query = f'select * from "{procname}"({s})'
+        query = f'select * from "{procname}"({s})'  # noqa: S608
         self.execute(query, parameters)
         return parameters
diff --git a/pyproject.toml b/pyproject.toml
index 2abfeb63..f5927d43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,8 +50,9 @@ select = [
     "N",  # pep8-naming
     "UP",  # pyupgrade
     "D",  # pydocstyle
-    "RUF",  # ruff
     "B",  # bugbear
+    "S",  # bandit
+    "RUF",  # ruff
 ]
 exclude = [
     "__pycache__",
@@ -69,7 +70,7 @@ exclude = [
 ]
 
 [tool.ruff.per-file-ignores]
-"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107"]
+"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"]
 
 [tool.setuptools]
 py-modules = ["pg", "pgdb"]
diff --git a/setup.py b/setup.py
index 6e4c5fd4..c20c9607 100755
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,7 @@
 
 def pg_config(s):
     """Retrieve information about installed version of PostgreSQL."""
-    f = os.popen(f'pg_config --{s}')
+    f = os.popen(f'pg_config --{s}')  # noqa: S605
    d = 
f.readline().strip() if f.close() is not None: raise Exception("pg_config tool is not available.") From d844e8a6710d74671fafb976406014a3fbbd07c1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 11:16:20 +0200 Subject: [PATCH 146/194] Add testing with flake8-simplify --- pg.py | 21 ++++++--------------- pgdb.py | 5 ++--- pyproject.toml | 1 + tests/config.py | 2 +- tests/dbapi20.py | 5 ++--- tests/test_classic.py | 17 +++++------------ tests/test_classic_connection.py | 28 +++++++++------------------- tests/test_classic_dbwrapper.py | 27 ++++++++------------------- tests/test_classic_largeobj.py | 21 ++++++++------------- tests/test_dbapi20_copy.py | 13 ++++--------- 10 files changed, 46 insertions(+), 94 deletions(-) diff --git a/pg.py b/pg.py index ac86d480..434dc906 100644 --- a/pg.py +++ b/pg.py @@ -23,6 +23,7 @@ import select import weakref from collections import OrderedDict, namedtuple +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from functools import lru_cache, partial @@ -1507,11 +1508,9 @@ def __init__(self, *args, **kw): if isinstance(db, DB): db = db.db else: - try: + with suppress(AttributeError): # noinspection PyUnresolvedReferences db = db._cnx - except AttributeError: - pass if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): db = connect(*args, **kw) self._db_args = args, kw @@ -1592,15 +1591,11 @@ def __del__(self): except AttributeError: db = None if db: - try: + with suppress(TypeError): # when already closed db.set_cast_hook(None) - except TypeError: - pass # probably already closed if self._closeable: - try: + with suppress(InternalError): # when already closed db.close() - except InternalError: - pass # probably already closed # Auxiliary methods @@ -1661,10 +1656,8 @@ def close(self): # Wraps shared library function so we can track state. 
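# (contextlib.suppress, imported above, is a flat equivalent of the
# try/except/pass blocks it replaces throughout this change; as a sketch:
#     with suppress(TypeError):  # when the connection is already closed
#         db.set_cast_hook(None)
# behaves the same as:
#     try:
#         db.set_cast_hook(None)
#     except TypeError:
#         pass
# but reads as a single statement.)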
db = self.db if db: - try: + with suppress(TypeError): # when already closed db.set_cast_hook(None) - except TypeError: - pass # probably already closed if self._closeable: db.close() self.db = None @@ -2611,10 +2604,8 @@ def get_as_list(self, table, what=None, where=None, try: order = self.pkey(table, True) except (KeyError, ProgrammingError): - try: + with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) - except (KeyError, ProgrammingError): - pass if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) diff --git a/pgdb.py b/pgdb.py index 22f99498..2e48e39d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -66,6 +66,7 @@ from collections import namedtuple from collections.abc import Iterable +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal as StdDecimal from functools import lru_cache, partial @@ -1442,10 +1443,8 @@ def close(self): """Close the connection object.""" if self._cnx: if self._tnx: - try: + with suppress(DatabaseError): self.rollback() - except DatabaseError: - pass self._cnx.close() self._cnx = None else: diff --git a/pyproject.toml b/pyproject.toml index f5927d43..131308b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ select = [ "D", # pydocstyle "B", # bugbear "S", # bandit + "SIM", # simplify "RUF", # ruff ] exclude = [ diff --git a/tests/config.py b/tests/config.py index acd8559a..f6280548 100644 --- a/tests/config.py +++ b/tests/config.py @@ -28,7 +28,7 @@ try: from .LOCAL_PyGreSQL import * # noqa: F403 except (ImportError, ValueError): - try: + try: # noqa: SIM105 from LOCAL_PyGreSQL import * # noqa: F403 except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 12a7647b..d5f2938f 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -11,6 +11,7 @@ import time import unittest +from contextlib import suppress from typing import Any, Mapping, Tuple @@ -181,11 +182,9 @@ def test_rollback(self): # If rollback is defined, it should either work or throw # the documented exception if hasattr(con, 'rollback'): - try: + with suppress(self.driver.NotSupportedError): # noinspection PyCallingNonCallable con.rollback() - except self.driver.NotSupportedError: - pass def test_cursor(self): con = self._connect() diff --git a/tests/test_classic.py b/tests/test_classic.py index 18af07b2..a6f78197 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,6 +1,7 @@ #!/usr/bin/python import unittest +from contextlib import suppress from functools import partial from threading import Thread from time import sleep @@ -34,27 +35,19 @@ class UtilityTest(unittest.TestCase): def setUpClass(cls): """Recreate test tables and schemas.""" db = open_db() - try: + with suppress(Exception): db.query("DROP VIEW _test_vschema") - except Exception: - pass - try: + with suppress(Exception): db.query("DROP TABLE _test_schema") - except Exception: - pass db.query("CREATE TABLE _test_schema" " (_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") db.query("CREATE VIEW _test_vschema AS" " SELECT _test, 'abc'::text AS _test2 FROM _test_schema") for t in ('_test1', '_test2'): - try: + with suppress(Exception): db.query("CREATE SCHEMA " + t) - except Exception: - pass - try: + with suppress(Exception): db.query(f"DROP TABLE {t}._test_schema") - except Exception: - pass db.query(f"CREATE TABLE {t}._test_schema" f" ({t} int PRIMARY KEY)") db.close() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 
1fa9edb6..7d4409df 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -15,6 +15,7 @@ import unittest from collections import namedtuple from collections.abc import Iterable +from contextlib import suppress from decimal import Decimal from typing import Sequence, Tuple @@ -94,10 +95,8 @@ def setUp(self): self.connection = connect() def tearDown(self): - try: + with suppress(pg.InternalError): self.connection.close() - except pg.InternalError: - pass def is_method(self, attribute): """Check if given attribute on the connection is a method.""" @@ -152,10 +151,7 @@ def test_attribute_error(self): @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) def test_attribute_host(self): - if dbhost and not dbhost.startswith('/'): - host = dbhost - else: - host = 'localhost' + host = dbhost if dbhost and not dbhost.startswith('/') else 'localhost' self.assertIsInstance(self.connection.host, str) self.assertEqual(self.connection.host, host) @@ -282,10 +278,8 @@ def test_all_query_members(self): self.assertEqual(members, query_members) def test_method_endcopy(self): - try: + with suppress(OSError): self.connection.endcopy() - except OSError: - pass def test_method_close(self): self.connection.close() @@ -1255,9 +1249,9 @@ def assert_proper_cast(self, value, pgtype, pytype): self.fail(str(e)) # noinspection PyUnboundLocalVariable self.assertIsInstance(r, pytype) - if isinstance(value, str): - if not value or ' ' in value or '{' in value: - value = f'"{value}"' + if isinstance(value, str) and ( + not value or ' ' in value or '{' in value): + value = f'"{value}"' value = f'{{{value}}}' r = self.c.query(q + '[]', (value,)).getresult()[0][0] if pgtype.startswith(('date', 'time', 'interval')): @@ -2194,10 +2188,8 @@ def test_getline(self): elif i == n: self.assertIsNone(v) finally: - try: + with suppress(OSError): self.c.endcopy() - except OSError: - pass def test_getline_bytes_and_unicode(self): getline = self.c.getline @@ -2218,10 +2210,8 @@ def test_getline_bytes_and_unicode(self): self.assertEqual(v, '73\twürstel') self.assertIsNone(getline()) finally: - try: + with suppress(OSError): self.c.endcopy() - except OSError: - pass def test_parameter_checks(self): self.assertRaises(TypeError, self.c.putline) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index c563d932..3884436f 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -16,6 +16,7 @@ import tempfile import unittest from collections import OrderedDict +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from io import StringIO @@ -167,10 +168,8 @@ def setUp(self): self.db = DB() def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass def test_all_db_attributes(self): attributes = [ @@ -223,10 +222,7 @@ def test_attribute_error(self): @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) def test_attribute_host(self): - if dbhost and not dbhost.startswith('/'): - host = dbhost - else: - host = 'localhost' + host = dbhost if dbhost and not dbhost.startswith('/') else 'localhost' self.assertIsInstance(self.db.host, str) self.assertEqual(self.db.host, host) self.assertEqual(self.db.db.host, host) @@ -334,10 +330,8 @@ def test_method_query_data_error(self): self.assertEqual(error.sqlstate, '22012') def test_method_endcopy(self): - try: + with suppress(OSError): self.db.endcopy() - except OSError: - pass def 
test_method_close(self): self.db.close() @@ -4352,10 +4346,8 @@ def setUp(self): self.adapter = self.db.adapter def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass def test_guess_simple_type(self): f = self.adapter.guess_simple_type @@ -4744,12 +4736,9 @@ def tearDown(self): def test_get_tables(self): tables = self.db.get_tables() for num_schema in range(5): - if num_schema: - schema = "s" + str(num_schema) - else: - schema = "public" - for t in (schema + ".t", - schema + ".t" + str(num_schema)): + schema = 's' + str(num_schema) if num_schema else 'public' + for t in (schema + '.t', + schema + '.t' + str(num_schema)): self.assertIn(t, tables) def test_get_attnames(self): diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index afe48a21..7e5ad4a2 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from contextlib import suppress import pg # the module under test @@ -107,7 +108,7 @@ def test_lo_import(self): if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_import.txt' - f = open(fname, 'wb') + f = open(fname, 'wb') # noqa: SIM115 else: f = tempfile.NamedTemporaryFile() fname = f.name @@ -115,7 +116,7 @@ def test_lo_import(self): f.write(data) if windows: f.close() - f = open(fname, 'rb') + f = open(fname, 'rb') # noqa: SIM115 else: f.flush() f.seek(0) @@ -149,19 +150,13 @@ def setUp(self): def tearDown(self): if self.obj.oid: - try: + with suppress(SystemError, OSError): self.obj.close() - except (SystemError, OSError): - pass - try: + with suppress(SystemError, OSError): self.obj.unlink() - except (SystemError, OSError): - pass del self.obj - try: + with suppress(SystemError): self.pgcnx.query('rollback') - except SystemError: - pass self.pgcnx.close() def test_class_name(self): @@ -420,7 +415,7 @@ def test_export(self): if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_export.txt' - f = open(fname, 'wb') + f = open(fname, 'wb') # noqa: SIM115 else: f = tempfile.NamedTemporaryFile() fname = f.name @@ -433,7 +428,7 @@ def test_export(self): export(fname) if windows: f.close() - f = open(fname, 'rb') + f = open(fname, 'rb') # noqa: SIM115 r = f.read() f.close() if windows: diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index ca775001..bcacd476 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,6 +11,7 @@ import unittest from collections.abc import Iterable +from contextlib import suppress from typing import Sequence, Tuple import pgdb # the module under test @@ -142,18 +143,12 @@ def setUp(self): self.cursor.execute("set client_encoding=utf8") def tearDown(self): - try: + with suppress(Exception): self.cursor.close() - except Exception: - pass - try: + with suppress(Exception): self.con.rollback() - except Exception: - pass - try: + with suppress(Exception): self.con.close() - except Exception: - pass data: Sequence[Tuple[int, str]] = [ (1935, 'Luciano Pavarotti'), From fadd20762006f3c08f55abe916c4a19274fd2fb5 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 11:31:43 +0200 Subject: [PATCH 147/194] Add some type hints --- pg.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pg.py b/pg.py index 434dc906..d29cb5c2 100644 --- a/pg.py +++ b/pg.py @@ -34,7 +34,7 @@ from operator import itemgetter from re import compile as regex from types import 
MappingProxyType -from typing import ClassVar, Dict, List, Mapping, Type, Union +from typing import Callable, ClassVar, Dict, List, Mapping, Type, Union from uuid import UUID try: @@ -298,6 +298,8 @@ def _quote_if_unqualified(param, name): class _ParameterList(list): """Helper class for building typed parameter lists.""" + adapt: Callable + def add(self, value, typ=None): """Typecast value with known database type and build parameter list. @@ -1149,6 +1151,18 @@ class DbType(str): attnames: attributes for composite types """ + oid: int + pgtype: str + regtype: str + simple: str + typlen: int + typtype: str + category: str + delim: str + relid: int + + _get_attnames: Callable + @property def attnames(self): """Get names and types of the fields of a composite type.""" From 47f19c189ca0deb98dd165489de8b3c02ee7b02b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 15:06:31 +0200 Subject: [PATCH 148/194] Use clang-format for C files --- .clang-format | 25 ++ .devcontainer/provision.sh | 2 +- pgconn.c | 778 ++++++++++++++++++++----------------- pginternal.c | 527 +++++++++++++++---------- pglarge.c | 151 ++++--- pgmodule.c | 601 ++++++++++++++-------------- pgnotice.c | 73 ++-- pgquery.c | 435 +++++++++++---------- pgsource.c | 290 +++++++------- tox.ini | 9 +- 10 files changed, 1557 insertions(+), 1334 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..22f8603a --- /dev/null +++ b/.clang-format @@ -0,0 +1,25 @@ +# A clang-format style that approximates Python's PEP 7 +# Useful for IDE integration +# +# Based on Paul Ganssle's version at +# https://gist.github.com/pganssle/0e3a5f828b4d07d79447f6ced8e7e4db +BasedOnStyle: Google +AlwaysBreakAfterReturnType: All +AllowShortIfStatementsOnASingleLine: false +AlignAfterOpenBracket: Align +AlignTrailingComments: true +BreakBeforeBraces: Stroustrup +ColumnLimit: 79 +DerivePointerAlignment: false +IndentWidth: 4 +Language: Cpp +PointerAlignment: Right +ReflowComments: true +SpaceBeforeParens: ControlStatements +SpacesInParentheses: false +TabWidth: 4 +UseCRLF: false +UseTab: Never +StatementMacros: + - Py_BEGIN_ALLOW_THREADS + - Py_END_ALLOW_THREADS \ No newline at end of file diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 2f3651d6..c780e7df 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -37,7 +37,7 @@ python3.11 -m pip install build pip install ruff -sudo apt-get install -y tox +sudo apt-get install -y tox clang-format # install PostgreSQL client tools diff --git a/pgconn.c b/pgconn.c index c67e74dc..10e5b780 100644 --- a/pgconn.c +++ b/pgconn.c @@ -95,10 +95,12 @@ conn_getattr(connObject *self, PyObject *nameobj) /* whether the connection uses SSL */ if (!strcmp(name, "ssl_in_use")) { if (PQsslInUse(self->cnx)) { - Py_INCREF(Py_True); return Py_True; + Py_INCREF(Py_True); + return Py_True; } else { - Py_INCREF(Py_False); return Py_False; + Py_INCREF(Py_False); + return Py_False; } } @@ -107,7 +109,7 @@ conn_getattr(connObject *self, PyObject *nameobj) return get_ssl_attributes(self->cnx); } - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Check connection validity. */ @@ -123,7 +125,7 @@ _check_cnx_obj(connObject *self) /* Create source object. 
*/ static char conn_source__doc__[] = -"source() -- create a new source object for this connection"; + "source() -- create a new source object for this connection"; static PyObject * conn_source(connObject *self, PyObject *noargs) @@ -147,13 +149,13 @@ conn_source(connObject *self, PyObject *noargs) source_obj->valid = 1; source_obj->arraysize = PG_ARRAYSIZE; - return (PyObject *) source_obj; + return (PyObject *)source_obj; } /* For a non-query result, set the appropriate error status, return the appropriate value, and free the result set. */ static PyObject * -_conn_non_query_result(int status, PGresult* result, PGconn *cnx) +_conn_non_query_result(int status, PGresult *result, PGconn *cnx) { switch (status) { case PGRES_EMPTY_QUERY: @@ -162,29 +164,27 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) case PGRES_BAD_RESPONSE: case PGRES_FATAL_ERROR: case PGRES_NONFATAL_ERROR: - set_error(ProgrammingError, "Cannot execute query", - cnx, result); + set_error(ProgrammingError, "Cannot execute query", cnx, result); break; - case PGRES_COMMAND_OK: - { /* INSERT, UPDATE, DELETE */ - Oid oid = PQoidValue(result); + case PGRES_COMMAND_OK: { /* INSERT, UPDATE, DELETE */ + Oid oid = PQoidValue(result); - if (oid == InvalidOid) { /* not a single insert */ - char *ret = PQcmdTuples(result); + if (oid == InvalidOid) { /* not a single insert */ + char *ret = PQcmdTuples(result); - if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyUnicode_FromString(ret); - PQclear(result); - return obj; - } + if (ret[0]) { /* return number of rows affected */ + PyObject *obj = PyUnicode_FromString(ret); PQclear(result); - Py_INCREF(Py_None); - return Py_None; + return obj; } - /* for a single insert, return the oid */ PQclear(result); - return PyLong_FromLong((long) oid); + Py_INCREF(Py_None); + return Py_None; } + /* for a single insert, return the oid */ + PQclear(result); + return PyLong_FromLong((long)oid); + } case PGRES_COPY_OUT: /* no data will be received */ case PGRES_COPY_IN: PQclear(result); @@ -196,15 +196,15 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) PQclear(result); return NULL; /* error detected on query */ - } +} /* Base method for execution of all different kinds of queries */ static PyObject * _conn_query(connObject *self, PyObject *args, int prepared, int async) { PyObject *query_str_obj, *param_obj = NULL; - PGresult* result; - queryObject* query_obj; + PGresult *result; + queryObject *query_obj; char *query; int encoding, status, nparms = 0; @@ -226,7 +226,8 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) } else if (PyUnicode_Check(query_str_obj)) { query_str_obj = get_encoded_string(query_str_obj, encoding); - if (!query_str_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!query_str_obj) + return NULL; /* pass the UnicodeEncodeError */ query = PyBytes_AsString(query_str_obj); } else { @@ -246,7 +247,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) Py_XDECREF(query_str_obj); return NULL; } - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); /* if there's a single argument and it's a list or tuple, it * contains the positional arguments. 
*/ @@ -255,7 +256,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) if (PyList_Check(first_obj) || PyTuple_Check(first_obj)) { Py_DECREF(param_obj); param_obj = PySequence_Fast(first_obj, NULL); - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); } } } @@ -267,11 +268,13 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) const char **parms, **p; register int i; - str = (PyObject **) PyMem_Malloc((size_t) nparms * sizeof(*str)); - parms = (const char **) PyMem_Malloc((size_t) nparms * sizeof(*parms)); + str = (PyObject **)PyMem_Malloc((size_t)nparms * sizeof(*str)); + parms = (const char **)PyMem_Malloc((size_t)nparms * sizeof(*parms)); if (!str || !parms) { - PyMem_Free((void *) parms); PyMem_Free(str); - Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); + PyMem_Free((void *)parms); + PyMem_Free(str); + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); return PyErr_NoMemory(); } @@ -290,8 +293,11 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) else if (PyUnicode_Check(obj)) { PyObject *str_obj = get_encoded_string(obj, encoding); if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); @@ -304,8 +310,11 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) else { PyObject *str_obj = PyObject_Str(obj); if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); @@ -321,22 +330,25 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) Py_BEGIN_ALLOW_THREADS if (async) { - status = PQsendQueryParams(self->cnx, query, nparms, - NULL, (const char * const *)parms, NULL, NULL, 0); + status = + PQsendQueryParams(self->cnx, query, nparms, NULL, + (const char *const *)parms, NULL, NULL, 0); result = NULL; } else { - result = prepared ? - PQexecPrepared(self->cnx, query, nparms, - parms, NULL, NULL, 0) : - PQexecParams(self->cnx, query, nparms, - NULL, parms, NULL, NULL, 0); + result = prepared ? PQexecPrepared(self->cnx, query, nparms, parms, + NULL, NULL, 0) + : PQexecParams(self->cnx, query, nparms, NULL, + parms, NULL, NULL, 0); status = result != NULL; } Py_END_ALLOW_THREADS - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); } else { @@ -346,10 +358,9 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) result = NULL; } else { - result = prepared ? - PQexecPrepared(self->cnx, query, 0, - NULL, NULL, NULL, 0) : - PQexec(self->cnx, query); + result = prepared ? 
PQexecPrepared(self->cnx, query, 0, NULL, NULL, + NULL, 0) + : PQexec(self->cnx, query); status = result != NULL; } Py_END_ALLOW_THREADS @@ -399,14 +410,14 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) } } - return (PyObject *) query_obj; + return (PyObject *)query_obj; } /* Database query */ static char conn_query__doc__[] = -"query(sql, [arg]) -- create a new query object for this connection\n\n" -"You must pass the SQL (string) request and you can optionally pass\n" -"a tuple with positional parameters.\n"; + "query(sql, [arg]) -- create a new query object for this connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; static PyObject * conn_query(connObject *self, PyObject *args) @@ -416,9 +427,10 @@ conn_query(connObject *self, PyObject *args) /* Asynchronous database query */ static char conn_send_query__doc__[] = -"send_query(sql, [arg]) -- create a new asynchronous query for this connection\n\n" -"You must pass the SQL (string) request and you can optionally pass\n" -"a tuple with positional parameters.\n"; + "send_query(sql, [arg]) -- create a new asynchronous query for this " + "connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; static PyObject * conn_send_query(connObject *self, PyObject *args) @@ -428,9 +440,9 @@ conn_send_query(connObject *self, PyObject *args) /* Execute prepared statement. */ static char conn_query_prepared__doc__[] = -"query_prepared(name, [arg]) -- execute a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and you can\n" -"optionally pass a tuple with positional parameters.\n"; + "query_prepared(name, [arg]) -- execute a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and you can\n" + "optionally pass a tuple with positional parameters.\n"; static PyObject * conn_query_prepared(connObject *self, PyObject *args) @@ -440,9 +452,9 @@ conn_query_prepared(connObject *self, PyObject *args) /* Create prepared statement. */ static char conn_prepare__doc__[] = -"prepare(name, sql) -- create a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and the\n" -"SQL (string) request for later execution.\n"; + "prepare(name, sql) -- create a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and the\n" + "SQL (string) request for later execution.\n"; static PyObject * conn_prepare(connObject *self, PyObject *args) @@ -457,9 +469,8 @@ conn_prepare(connObject *self, PyObject *args) } /* reads args */ - if (!PyArg_ParseTuple(args, "s#s#", - &name, &name_length, &query, &query_length)) - { + if (!PyArg_ParseTuple(args, "s#s#", &name, &name_length, &query, + &query_length)) { PyErr_SetString(PyExc_TypeError, "Method prepare() takes two string arguments"); return NULL; @@ -474,8 +485,8 @@ conn_prepare(connObject *self, PyObject *args) Py_INCREF(Py_None); return Py_None; /* success */ } - set_error(ProgrammingError, "Cannot create prepared statement", - self->cnx, result); + set_error(ProgrammingError, "Cannot create prepared statement", self->cnx, + result); if (result) PQclear(result); return NULL; /* error */ @@ -483,8 +494,8 @@ conn_prepare(connObject *self, PyObject *args) /* Describe prepared statement. 
*/ static char conn_describe_prepared__doc__[] = -"describe_prepared(name) -- describe a prepared statement\n\n" -"You must pass the name (string) of the prepared statement.\n"; + "describe_prepared(name) -- describe a prepared statement\n\n" + "You must pass the name (string) of the prepared statement.\n"; static PyObject * conn_describe_prepared(connObject *self, PyObject *args) @@ -521,17 +532,17 @@ conn_describe_prepared(connObject *self, PyObject *args) query_obj->max_row = PQntuples(result); query_obj->num_fields = PQnfields(result); query_obj->col_types = get_col_types(result, query_obj->num_fields); - return (PyObject *) query_obj; + return (PyObject *)query_obj; } set_error(ProgrammingError, "Cannot describe prepared statement", - self->cnx, result); + self->cnx, result); if (result) PQclear(result); return NULL; /* error */ } static char conn_putline__doc__[] = -"putline(line) -- send a line directly to the backend"; + "putline(line) -- send a line directly to the backend"; /* Direct access function: putline. */ static PyObject * @@ -554,12 +565,14 @@ conn_putline(connObject *self, PyObject *args) } /* send line to backend */ - ret = PQputCopyData(self->cnx, line, (int) line_length); + ret = PQputCopyData(self->cnx, line, (int)line_length); if (ret != 1) { - PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : - "Line cannot be queued, wait for write-ready and try again"); + PyErr_SetString( + PyExc_IOError, + ret == -1 + ? PQerrorMessage(self->cnx) + : "Line cannot be queued, wait for write-ready and try again"); return NULL; - } Py_INCREF(Py_None); return Py_None; @@ -567,7 +580,7 @@ conn_putline(connObject *self, PyObject *args) /* Direct access function: getline. */ static char conn_getline__doc__[] = -"getline() -- get a line directly from the backend"; + "getline() -- get a line directly from the backend"; static PyObject * conn_getline(connObject *self, PyObject *noargs) @@ -586,15 +599,18 @@ conn_getline(connObject *self, PyObject *noargs) /* check result */ if (ret <= 0) { - if (line != NULL) PQfreemem(line); + if (line != NULL) + PQfreemem(line); if (ret == -1) { PQgetResult(self->cnx); Py_INCREF(Py_None); return Py_None; } - PyErr_SetString(PyExc_MemoryError, - ret == -2 ? PQerrorMessage(self->cnx) : - "No line available, wait for read-ready and try again"); + PyErr_SetString( + PyExc_MemoryError, + ret == -2 + ? PQerrorMessage(self->cnx) + : "No line available, wait for read-ready and try again"); return NULL; } if (line == NULL) { @@ -602,7 +618,8 @@ conn_getline(connObject *self, PyObject *noargs) return Py_None; } /* for backward compatibility, convert terminating newline to zero byte */ - if (*line) line[strlen(line) - 1] = '\0'; + if (*line) + line[strlen(line) - 1] = '\0'; str = PyUnicode_FromString(line); PQfreemem(line); return str; @@ -610,7 +627,7 @@ conn_getline(connObject *self, PyObject *noargs) /* Direct access function: end copy. */ static char conn_endcopy__doc__[] = -"endcopy() -- synchronize client and server"; + "endcopy() -- synchronize client and server"; static PyObject * conn_endcopy(connObject *self, PyObject *noargs) @@ -624,11 +641,11 @@ conn_endcopy(connObject *self, PyObject *noargs) /* end direct copy */ ret = PQputCopyEnd(self->cnx, NULL); - if (ret != 1) - { - PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : - "Termination message cannot be queued," - " wait for write-ready and try again"); + if (ret != 1) { + PyErr_SetString(PyExc_IOError, + ret == -1 ? 
PQerrorMessage(self->cnx) + : "Termination message cannot be queued," + " wait for write-ready and try again"); return NULL; } Py_INCREF(Py_None); @@ -637,7 +654,7 @@ conn_endcopy(connObject *self, PyObject *noargs) /* Direct access function: set blocking status. */ static char conn_set_non_blocking__doc__[] = -"set_non_blocking() -- set the non-blocking status of the connection"; + "set_non_blocking() -- set the non-blocking status of the connection"; static PyObject * conn_set_non_blocking(connObject *self, PyObject *args) @@ -666,7 +683,7 @@ conn_set_non_blocking(connObject *self, PyObject *args) /* Direct access function: get blocking status. */ static char conn_is_non_blocking__doc__[] = -"is_non_blocking() -- report the blocking status of the connection"; + "is_non_blocking() -- report the blocking status of the connection"; static PyObject * conn_is_non_blocking(connObject *self, PyObject *noargs) @@ -687,12 +704,11 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) return PyBool_FromLong((long)rc); } - /* Insert table */ static char conn_inserttable__doc__[] = -"inserttable(table, data, [columns]) -- insert iterable into table\n\n" -"The fields in the iterable must be in the same order as in the table\n" -"or in the list or tuple of columns if one is specified.\n"; + "inserttable(table, data, [columns]) -- insert iterable into table\n\n" + "The fields in the iterable must be in the same order as in the table\n" + "or in the list or tuple of columns if one is specified.\n"; static PyObject * conn_inserttable(connObject *self, PyObject *args) @@ -718,8 +734,7 @@ conn_inserttable(connObject *self, PyObject *args) } /* checks list type */ - if (!(iter_row = PyObject_GetIter(rows))) - { + if (!(iter_row = PyObject_GetIter(rows))) { PyErr_SetString( PyExc_TypeError, "Method inserttable() expects an iterable as second argument"); @@ -728,31 +743,36 @@ conn_inserttable(connObject *self, PyObject *args) m = PySequence_Check(rows) ? 
PySequence_Size(rows) : -1; if (!m) { /* no rows specified, nothing to do */ - Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; } /* checks columns type */ if (columns) { if (!(PyTuple_Check(columns) || PyList_Check(columns))) { - PyErr_SetString( - PyExc_TypeError, - "Method inserttable() expects a tuple or a list" - " as third argument"); + PyErr_SetString(PyExc_TypeError, + "Method inserttable() expects a tuple or a list" + " as third argument"); return NULL; } n = PySequence_Fast_GET_SIZE(columns); if (!n) { /* no columns specified, nothing to do */ - Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; } - } else { + } + else { n = -1; /* number of columns not yet known */ } /* allocate buffer */ if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) { - Py_DECREF(iter_row); return PyErr_NoMemory(); + Py_DECREF(iter_row); + return PyErr_NoMemory(); } encoding = PQclientEncoding(self->cnx); @@ -760,22 +780,26 @@ conn_inserttable(connObject *self, PyObject *args) /* starts query */ bufpt = buffer; bufmax = bufpt + MAX_BUFFER_SIZE; - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy "); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "copy "); s = table; do { - t = strchr(s, '.'); if (!t) t = s + strlen(s); - table = PQescapeIdentifier(self->cnx, s, (size_t) (t - s)); + t = strchr(s, '.'); + if (!t) + t = s + strlen(s); + table = PQescapeIdentifier(self->cnx, s, (size_t)(t - s)); if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s", table); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s", table); PQfreemem(table); - s = t; if (*s && bufpt < bufmax) *bufpt++ = *s++; + s = t; + if (*s && bufpt < bufmax) + *bufpt++ = *s++; } while (*s); if (columns) { /* adds a string like f" ({','.join(columns)})" */ if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), " ("); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " ("); for (j = 0; j < n; ++j) { PyObject *obj = PySequence_Fast_GET_ITEM(columns, j); Py_ssize_t slen; @@ -787,29 +811,33 @@ conn_inserttable(connObject *self, PyObject *args) else if (PyUnicode_Check(obj)) { obj = get_encoded_string(obj, encoding); if (!obj) { - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } - } else { + } + else { PyErr_SetString( PyExc_TypeError, "The third argument must contain only strings"); - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; } PyBytes_AsStringAndSize(obj, &col, &slen); - col = PQescapeIdentifier(self->cnx, col, (size_t) slen); + col = PQescapeIdentifier(self->cnx, col, (size_t)slen); Py_DECREF(obj); if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), - "%s%s", col, j == n - 1 ? ")" : ","); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s%s", col, + j == n - 1 ? 
")" : ","); PQfreemem(col); } } if (bufpt < bufmax) - snprintf(bufpt, (size_t) (bufmax - bufpt), " from stdin"); - if (bufpt >= bufmax) { - PyMem_Free(buffer); Py_DECREF(iter_row); + snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); + if (bufpt >= bufmax) { + PyMem_Free(buffer); + Py_DECREF(iter_row); return PyErr_NoMemory(); } @@ -818,7 +846,8 @@ conn_inserttable(connObject *self, PyObject *args) Py_END_ALLOW_THREADS if (!result || PQresultStatus(result) != PGRES_COPY_IN) { - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; } @@ -827,12 +856,15 @@ conn_inserttable(connObject *self, PyObject *args) /* feed table */ for (i = 0; m < 0 || i < m; ++i) { - - if (!(columns = PyIter_Next(iter_row))) break; + if (!(columns = PyIter_Next(iter_row))) + break; if (!(PyTuple_Check(columns) || PyList_Check(columns))) { - PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row); + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(columns); + Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, "The second argument must contain tuples or lists"); @@ -842,9 +874,12 @@ conn_inserttable(connObject *self, PyObject *args) j = PySequence_Fast_GET_SIZE(columns); if (n < 0) { n = j; - } else if (j != n) { - PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(iter_row); + } + else if (j != n) { + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, "The second arg must contain sequences of the same size"); @@ -857,7 +892,8 @@ conn_inserttable(connObject *self, PyObject *args) for (j = 0; j < n; ++j) { if (j) { - *bufpt++ = '\t'; --bufsiz; + *bufpt++ = '\t'; + --bufsiz; } item = PySequence_Fast_GET_ITEM(columns, j); @@ -865,37 +901,43 @@ conn_inserttable(connObject *self, PyObject *args) /* convert item to string and append to buffer */ if (item == Py_None) { if (bufsiz > 2) { - *bufpt++ = '\\'; *bufpt++ = 'N'; + *bufpt++ = '\\'; + *bufpt++ = 'N'; bufsiz -= 2; } else bufsiz = 0; } else if (PyBytes_Check(item)) { - const char* t = PyBytes_AsString(item); + const char *t = PyBytes_AsString(item); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } } else if (PyUnicode_Check(item)) { @@ -903,83 +945,97 @@ conn_inserttable(connObject *self, PyObject *args) if (!s) { PQputCopyEnd(self->cnx, "Encoding error"); PyMem_Free(buffer); - Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row); + Py_DECREF(item); + Py_DECREF(columns); + Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } else { - const char* t = PyBytes_AsString(s); + const char *t = PyBytes_AsString(s); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 
't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } Py_DECREF(s); } } else if (PyLong_Check(item)) { - PyObject* s = PyObject_Str(item); - const char* t = PyUnicode_AsUTF8(s); + PyObject *s = PyObject_Str(item); + const char *t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { - *bufpt++ = *t++; --bufsiz; + *bufpt++ = *t++; + --bufsiz; } Py_DECREF(s); } else { - PyObject* s = PyObject_Repr(item); - const char* t = PyUnicode_AsUTF8(s); + PyObject *s = PyObject_Repr(item); + const char *t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } Py_DECREF(s); } if (bufsiz <= 0) { - PQputCopyEnd(self->cnx, "Memory error"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(iter_row); + PQputCopyEnd(self->cnx, "Memory error"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); return PyErr_NoMemory(); } - } Py_DECREF(columns); @@ -987,13 +1043,14 @@ conn_inserttable(connObject *self, PyObject *args) *bufpt++ = '\n'; /* sends data */ - ret = PQputCopyData(self->cnx, buffer, (int) (bufpt - buffer)); + ret = PQputCopyData(self->cnx, buffer, (int)(bufpt - buffer)); if (ret != 1) { - char *errormsg = ret == - 1 ? - PQerrorMessage(self->cnx) : "Data cannot be queued"; + char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"; PyErr_SetString(PyExc_IOError, errormsg); PQputCopyEnd(self->cnx, errormsg); - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; } } @@ -1006,8 +1063,8 @@ conn_inserttable(connObject *self, PyObject *args) ret = PQputCopyEnd(self->cnx, NULL); if (ret != 1) { - PyErr_SetString(PyExc_IOError, ret == -1 ? - PQerrorMessage(self->cnx) : "Data cannot be queued"); + PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"); PyMem_Free(buffer); return NULL; } @@ -1021,7 +1078,8 @@ conn_inserttable(connObject *self, PyObject *args) PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); PQclear(result); return NULL; - } else { + } + else { long ntuples = atol(PQcmdTuples(result)); PQclear(result); return PyLong_FromLong(ntuples); @@ -1030,7 +1088,7 @@ conn_inserttable(connObject *self, PyObject *args) /* Get transaction state. */ static char conn_transaction__doc__[] = -"transaction() -- return the current transaction status"; + "transaction() -- return the current transaction status"; static PyObject * conn_transaction(connObject *self, PyObject *noargs) @@ -1045,7 +1103,7 @@ conn_transaction(connObject *self, PyObject *noargs) /* Get parameter setting. 
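   A quick Python-side sketch of looking up server parameters (the
   connection and database name are hypothetical; the parameter names are
   standard PQparameterStatus names):

       import pg

       cnx = pg.connect(dbname='test')
       print(cnx.parameter('server_version'))   # e.g. '12.2'
       print(cnx.parameter('DateStyle'))        # e.g. 'ISO, MDY'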
*/ static char conn_parameter__doc__[] = -"parameter(name) -- look up a current parameter setting"; + "parameter(name) -- look up a current parameter setting"; static PyObject * conn_parameter(connObject *self, PyObject *args) @@ -1076,7 +1134,7 @@ conn_parameter(connObject *self, PyObject *args) /* Get current date format. */ static char conn_date_format__doc__[] = -"date_format() -- return the current date format"; + "date_format() -- return the current date format"; static PyObject * conn_date_format(connObject *self, PyObject *noargs) @@ -1100,18 +1158,18 @@ conn_date_format(connObject *self, PyObject *noargs) /* Escape literal */ static char conn_escape_literal__doc__[] = -"escape_literal(str) -- escape a literal constant for use within SQL"; + "escape_literal(str) -- escape a literal constant for use within SQL"; static PyObject * conn_escape_literal(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1119,7 +1177,8 @@ conn_escape_literal(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -1129,15 +1188,15 @@ conn_escape_literal(connObject *self, PyObject *string) return NULL; } - to = PQescapeLiteral(self->cnx, from, (size_t) from_length); + to = PQescapeLiteral(self->cnx, from, (size_t)from_length); to_length = strlen(to); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); if (to) PQfreemem(to); return to_obj; @@ -1145,18 +1204,18 @@ conn_escape_literal(connObject *self, PyObject *string) /* Escape identifier */ static char conn_escape_identifier__doc__[] = -"escape_identifier(str) -- escape an identifier for use within SQL"; + "escape_identifier(str) -- escape an identifier for use within SQL"; static PyObject * conn_escape_identifier(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + 
Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1164,7 +1223,8 @@ conn_escape_identifier(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -1174,15 +1234,15 @@ conn_escape_identifier(connObject *self, PyObject *string) return NULL; } - to = PQescapeIdentifier(self->cnx, from, (size_t) from_length); + to = PQescapeIdentifier(self->cnx, from, (size_t)from_length); to_length = strlen(to); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); if (to) PQfreemem(to); return to_obj; @@ -1190,18 +1250,18 @@ conn_escape_identifier(connObject *self, PyObject *string) /* Escape string */ static char conn_escape_string__doc__[] = -"escape_string(str) -- escape a string for use within SQL"; + "escape_string(str) -- escape a string for use within SQL"; static PyObject * conn_escape_string(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1209,49 +1269,50 @@ conn_escape_string(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_string() expects a string as argument"); + PyErr_SetString(PyExc_TypeError, + "Method escape_string() expects a string as argument"); return NULL; } - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; } - to = (char *) PyMem_Malloc(to_length); - to_length = PQescapeStringConn(self->cnx, - to, from, (size_t) from_length, NULL); + to = (char *)PyMem_Malloc(to_length); + to_length = + PQescapeStringConn(self->cnx, to, from, (size_t)from_length, NULL); Py_XDECREF(tmp_obj); if (encoding == 
-1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); PyMem_Free(to); return to_obj; } /* Escape bytea */ static char conn_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; static PyObject * conn_escape_bytea(connObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); @@ -1259,25 +1320,25 @@ conn_escape_bytea(connObject *self, PyObject *data) else if (PyUnicode_Check(data)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_bytea() expects a string as argument"); + PyErr_SetString(PyExc_TypeError, + "Method escape_bytea() expects a string as argument"); return NULL; } - to = (char *) PQescapeByteaConn(self->cnx, - (unsigned char *) from, (size_t) from_length, &to_length); + to = (char *)PQescapeByteaConn(self->cnx, (unsigned char *)from, + (size_t)from_length, &to_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); if (to) PQfreemem(to); return to_obj; @@ -1303,7 +1364,7 @@ large_new(connObject *pgcnx, Oid oid) /* Create large object. */ static char conn_locreate__doc__[] = -"locreate(mode) -- create a new large object in the database"; + "locreate(mode) -- create a new large object in the database"; static PyObject * conn_locreate(connObject *self, PyObject *args) @@ -1330,12 +1391,12 @@ conn_locreate(connObject *self, PyObject *args) return NULL; } - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Init from already known oid. 
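   A minimal sketch of reattaching to a large object by oid from Python
   (hypothetical database; pg.INV_READ and pg.INV_WRITE are the access
   mode flags exported by the module):

       import pg

       cnx = pg.connect(dbname='test')
       lo = cnx.locreate(pg.INV_READ | pg.INV_WRITE)
       oid = lo.oid              # remember the oid across wrappers
       lo2 = cnx.getlo(oid)      # new large object instance, same oid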
*/ static char conn_getlo__doc__[] = -"getlo(oid) -- create a large object instance for the specified oid"; + "getlo(oid) -- create a large object instance for the specified oid"; static PyObject * conn_getlo(connObject *self, PyObject *args) @@ -1355,19 +1416,19 @@ conn_getlo(connObject *self, PyObject *args) return NULL; } - lo_oid = (Oid) oid; + lo_oid = (Oid)oid; if (lo_oid == 0) { PyErr_SetString(PyExc_ValueError, "The object oid can't be null"); return NULL; } /* creates object */ - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Import unix file. */ static char conn_loimport__doc__[] = -"loimport(name) -- create a new large object from specified file"; + "loimport(name) -- create a new large object from specified file"; static PyObject * conn_loimport(connObject *self, PyObject *args) @@ -1394,14 +1455,14 @@ conn_loimport(connObject *self, PyObject *args) return NULL; } - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Reset connection. */ static char conn_reset__doc__[] = -"reset() -- reset connection with current parameters\n\n" -"All derived queries and large objects derived from this connection\n" -"will not be usable after this call.\n"; + "reset() -- reset connection with current parameters\n\n" + "All derived queries and large objects derived from this connection\n" + "will not be usable after this call.\n"; static PyObject * conn_reset(connObject *self, PyObject *noargs) @@ -1419,7 +1480,7 @@ conn_reset(connObject *self, PyObject *noargs) /* Cancel current command. */ static char conn_cancel__doc__[] = -"cancel() -- abandon processing of the current command"; + "cancel() -- abandon processing of the current command"; static PyObject * conn_cancel(connObject *self, PyObject *noargs) @@ -1430,12 +1491,12 @@ conn_cancel(connObject *self, PyObject *noargs) } /* request that the server abandon processing of the current command */ - return PyLong_FromLong((long) PQrequestCancel(self->cnx)); + return PyLong_FromLong((long)PQrequestCancel(self->cnx)); } /* Get connection socket. */ static char conn_fileno__doc__[] = -"fileno() -- return database connection socket file handle"; + "fileno() -- return database connection socket file handle"; static PyObject * conn_fileno(connObject *self, PyObject *noargs) @@ -1445,12 +1506,12 @@ conn_fileno(connObject *self, PyObject *noargs) return NULL; } - return PyLong_FromLong((long) PQsocket(self->cnx)); + return PyLong_FromLong((long)PQsocket(self->cnx)); } /* Set external typecast callback function. */ static char conn_set_cast_hook__doc__[] = -"set_cast_hook(func) -- set a fallback typecast function"; + "set_cast_hook(func) -- set a fallback typecast function"; static PyObject * conn_set_cast_hook(connObject *self, PyObject *func) @@ -1460,12 +1521,15 @@ conn_set_cast_hook(connObject *self, PyObject *func) if (func == Py_None) { Py_XDECREF(self->cast_hook); self->cast_hook = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->cast_hook); + Py_XINCREF(func); + Py_XDECREF(self->cast_hook); self->cast_hook = func; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -1478,12 +1542,13 @@ conn_set_cast_hook(connObject *self, PyObject *func) /* Get notice receiver callback function. 
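   (Per its docstring, the getter below returns the fallback typecast
   hook installed with set_cast_hook above.) A small sketch of the pair
   from Python, assuming the hook is called with the raw value and the
   type name, as pg.py uses it:

       import pg

       cnx = pg.connect(dbname='test')      # hypothetical database

       def fallback(value, typ):            # assumed two-argument signature
           return value                     # pass unknown types through

       cnx.set_cast_hook(fallback)
       assert cnx.get_cast_hook() is fallback
       cnx.set_cast_hook(None)              # uninstall the hook again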
*/ static char conn_get_cast_hook__doc__[] = -"get_cast_hook() -- get the fallback typecast function"; + "get_cast_hook() -- get the fallback typecast function"; static PyObject * conn_get_cast_hook(connObject *self, PyObject *noargs) { - PyObject *ret = self->cast_hook;; + PyObject *ret = self->cast_hook; + ; if (!ret) ret = Py_None; @@ -1494,7 +1559,7 @@ conn_get_cast_hook(connObject *self, PyObject *noargs) /* Get asynchronous connection state. */ static char conn_poll__doc__[] = -"poll() -- Completes an asynchronous connection"; + "poll() -- Completes an asynchronous connection"; static PyObject * conn_poll(connObject *self, PyObject *noargs) @@ -1521,7 +1586,7 @@ conn_poll(connObject *self, PyObject *noargs) /* Set notice receiver callback function. */ static char conn_set_notice_receiver__doc__[] = -"set_notice_receiver(func) -- set the current notice receiver"; + "set_notice_receiver(func) -- set the current notice receiver"; static PyObject * conn_set_notice_receiver(connObject *self, PyObject *func) @@ -1531,13 +1596,16 @@ conn_set_notice_receiver(connObject *self, PyObject *func) if (func == Py_None) { Py_XDECREF(self->notice_receiver); self->notice_receiver = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->notice_receiver); + Py_XINCREF(func); + Py_XDECREF(self->notice_receiver); self->notice_receiver = func; PQsetNoticeReceiver(self->cnx, notice_receiver, self); - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -1550,7 +1618,7 @@ conn_set_notice_receiver(connObject *self, PyObject *func) /* Get notice receiver callback function. */ static char conn_get_notice_receiver__doc__[] = -"get_notice_receiver() -- get the current notice receiver"; + "get_notice_receiver() -- get the current notice receiver"; static PyObject * conn_get_notice_receiver(connObject *self, PyObject *noargs) @@ -1566,9 +1634,9 @@ conn_get_notice_receiver(connObject *self, PyObject *noargs) /* Close without deleting. */ static char conn_close__doc__[] = -"close() -- close connection\n\n" -"All instances of the connection object and derived objects\n" -"(queries and large objects) can no longer be used after this call.\n"; + "close() -- close connection\n\n" + "All instances of the connection object and derived objects\n" + "(queries and large objects) can no longer be used after this call.\n"; static PyObject * conn_close(connObject *self, PyObject *noargs) @@ -1590,7 +1658,7 @@ conn_close(connObject *self, PyObject *noargs) /* Get asynchronous notify. 
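   For orientation, the Python-side notification loop looks roughly like
   this (connection, database and channel name are made up):

       import pg

       cnx = pg.connect(dbname='test')
       cnx.query('listen my_event')
       # ... later, after some session ran: NOTIFY my_event, 'payload'
       note = cnx.getnotify()               # None while nothing is pending
       if note:
           channel, backend_pid, payload = note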
*/ static char conn_get_notify__doc__[] = -"getnotify() -- get database notify for this connection"; + "getnotify() -- get database notify for this connection"; static PyObject * conn_get_notify(connObject *self, PyObject *noargs) @@ -1649,87 +1717,74 @@ conn_dir(connObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssssssssssss]", - "host", "port", "db", "options", "error", "status", "user", - "protocol_version", "server_version", "socket", "backend_pid", - "ssl_in_use", "ssl_attributes"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssssssssssss]", "host", "port", + "db", "options", "error", "status", "user", + "protocol_version", "server_version", "socket", + "backend_pid", "ssl_in_use", "ssl_attributes"); return attrs; } /* Connection object methods */ static struct PyMethodDef conn_methods[] = { - {"__dir__", (PyCFunction) conn_dir, METH_NOARGS, NULL}, - - {"source", (PyCFunction) conn_source, - METH_NOARGS, conn_source__doc__}, - {"query", (PyCFunction) conn_query, - METH_VARARGS, conn_query__doc__}, - {"send_query", (PyCFunction) conn_send_query, - METH_VARARGS, conn_send_query__doc__}, - {"query_prepared", (PyCFunction) conn_query_prepared, - METH_VARARGS, conn_query_prepared__doc__}, - {"prepare", (PyCFunction) conn_prepare, - METH_VARARGS, conn_prepare__doc__}, - {"describe_prepared", (PyCFunction) conn_describe_prepared, - METH_VARARGS, conn_describe_prepared__doc__}, - {"poll", (PyCFunction) conn_poll, - METH_NOARGS, conn_poll__doc__}, - {"reset", (PyCFunction) conn_reset, - METH_NOARGS, conn_reset__doc__}, - {"cancel", (PyCFunction) conn_cancel, - METH_NOARGS, conn_cancel__doc__}, - {"close", (PyCFunction) conn_close, - METH_NOARGS, conn_close__doc__}, - {"fileno", (PyCFunction) conn_fileno, - METH_NOARGS, conn_fileno__doc__}, - {"get_cast_hook", (PyCFunction) conn_get_cast_hook, - METH_NOARGS, conn_get_cast_hook__doc__}, - {"set_cast_hook", (PyCFunction) conn_set_cast_hook, - METH_O, conn_set_cast_hook__doc__}, - {"get_notice_receiver", (PyCFunction) conn_get_notice_receiver, - METH_NOARGS, conn_get_notice_receiver__doc__}, - {"set_notice_receiver", (PyCFunction) conn_set_notice_receiver, - METH_O, conn_set_notice_receiver__doc__}, - {"getnotify", (PyCFunction) conn_get_notify, - METH_NOARGS, conn_get_notify__doc__}, - {"inserttable", (PyCFunction) conn_inserttable, - METH_VARARGS, conn_inserttable__doc__}, - {"transaction", (PyCFunction) conn_transaction, - METH_NOARGS, conn_transaction__doc__}, - {"parameter", (PyCFunction) conn_parameter, - METH_VARARGS, conn_parameter__doc__}, - {"date_format", (PyCFunction) conn_date_format, - METH_NOARGS, conn_date_format__doc__}, - - {"escape_literal", (PyCFunction) conn_escape_literal, - METH_O, conn_escape_literal__doc__}, - {"escape_identifier", (PyCFunction) conn_escape_identifier, - METH_O, conn_escape_identifier__doc__}, - {"escape_string", (PyCFunction) conn_escape_string, - METH_O, conn_escape_string__doc__}, - {"escape_bytea", (PyCFunction) conn_escape_bytea, - METH_O, conn_escape_bytea__doc__}, - - {"putline", (PyCFunction) conn_putline, - METH_VARARGS, conn_putline__doc__}, - {"getline", (PyCFunction) conn_getline, - METH_NOARGS, conn_getline__doc__}, - {"endcopy", (PyCFunction) conn_endcopy, - METH_NOARGS, conn_endcopy__doc__}, - {"set_non_blocking", (PyCFunction) conn_set_non_blocking, - METH_VARARGS, conn_set_non_blocking__doc__}, - {"is_non_blocking", (PyCFunction) 
conn_is_non_blocking, - METH_NOARGS, conn_is_non_blocking__doc__}, - - {"locreate", (PyCFunction) conn_locreate, - METH_VARARGS, conn_locreate__doc__}, - {"getlo", (PyCFunction) conn_getlo, - METH_VARARGS, conn_getlo__doc__}, - {"loimport", (PyCFunction) conn_loimport, - METH_VARARGS, conn_loimport__doc__}, + {"__dir__", (PyCFunction)conn_dir, METH_NOARGS, NULL}, + + {"source", (PyCFunction)conn_source, METH_NOARGS, conn_source__doc__}, + {"query", (PyCFunction)conn_query, METH_VARARGS, conn_query__doc__}, + {"send_query", (PyCFunction)conn_send_query, METH_VARARGS, + conn_send_query__doc__}, + {"query_prepared", (PyCFunction)conn_query_prepared, METH_VARARGS, + conn_query_prepared__doc__}, + {"prepare", (PyCFunction)conn_prepare, METH_VARARGS, conn_prepare__doc__}, + {"describe_prepared", (PyCFunction)conn_describe_prepared, METH_VARARGS, + conn_describe_prepared__doc__}, + {"poll", (PyCFunction)conn_poll, METH_NOARGS, conn_poll__doc__}, + {"reset", (PyCFunction)conn_reset, METH_NOARGS, conn_reset__doc__}, + {"cancel", (PyCFunction)conn_cancel, METH_NOARGS, conn_cancel__doc__}, + {"close", (PyCFunction)conn_close, METH_NOARGS, conn_close__doc__}, + {"fileno", (PyCFunction)conn_fileno, METH_NOARGS, conn_fileno__doc__}, + {"get_cast_hook", (PyCFunction)conn_get_cast_hook, METH_NOARGS, + conn_get_cast_hook__doc__}, + {"set_cast_hook", (PyCFunction)conn_set_cast_hook, METH_O, + conn_set_cast_hook__doc__}, + {"get_notice_receiver", (PyCFunction)conn_get_notice_receiver, METH_NOARGS, + conn_get_notice_receiver__doc__}, + {"set_notice_receiver", (PyCFunction)conn_set_notice_receiver, METH_O, + conn_set_notice_receiver__doc__}, + {"getnotify", (PyCFunction)conn_get_notify, METH_NOARGS, + conn_get_notify__doc__}, + {"inserttable", (PyCFunction)conn_inserttable, METH_VARARGS, + conn_inserttable__doc__}, + {"transaction", (PyCFunction)conn_transaction, METH_NOARGS, + conn_transaction__doc__}, + {"parameter", (PyCFunction)conn_parameter, METH_VARARGS, + conn_parameter__doc__}, + {"date_format", (PyCFunction)conn_date_format, METH_NOARGS, + conn_date_format__doc__}, + + {"escape_literal", (PyCFunction)conn_escape_literal, METH_O, + conn_escape_literal__doc__}, + {"escape_identifier", (PyCFunction)conn_escape_identifier, METH_O, + conn_escape_identifier__doc__}, + {"escape_string", (PyCFunction)conn_escape_string, METH_O, + conn_escape_string__doc__}, + {"escape_bytea", (PyCFunction)conn_escape_bytea, METH_O, + conn_escape_bytea__doc__}, + + {"putline", (PyCFunction)conn_putline, METH_VARARGS, conn_putline__doc__}, + {"getline", (PyCFunction)conn_getline, METH_NOARGS, conn_getline__doc__}, + {"endcopy", (PyCFunction)conn_endcopy, METH_NOARGS, conn_endcopy__doc__}, + {"set_non_blocking", (PyCFunction)conn_set_non_blocking, METH_VARARGS, + conn_set_non_blocking__doc__}, + {"is_non_blocking", (PyCFunction)conn_is_non_blocking, METH_NOARGS, + conn_is_non_blocking__doc__}, + + {"locreate", (PyCFunction)conn_locreate, METH_VARARGS, + conn_locreate__doc__}, + {"getlo", (PyCFunction)conn_getlo, METH_VARARGS, conn_getlo__doc__}, + {"loimport", (PyCFunction)conn_loimport, METH_VARARGS, + conn_loimport__doc__}, {NULL, NULL} /* sentinel */ }; @@ -1738,32 +1793,31 @@ static char conn__doc__[] = "PostgreSQL connection object"; /* Connection type definition */ static PyTypeObject connType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Connection", /* tp_name */ - sizeof(connObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor) conn_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* 
tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - (getattrofunc) conn_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - conn__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - conn_methods, /* tp_methods */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Connection", /* tp_name */ + sizeof(connObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)conn_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + (getattrofunc)conn_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + conn__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + conn_methods, /* tp_methods */ }; diff --git a/pginternal.c b/pginternal.c index 61446f41..124661c1 100644 --- a/pginternal.c +++ b/pginternal.c @@ -37,8 +37,8 @@ get_decoded_string(const char *str, Py_ssize_t size, int encoding) if (encoding == pg_encoding_ascii) return PyUnicode_DecodeASCII(str, size, "strict"); /* encoding name should be properly translated to Python here */ - return PyUnicode_Decode(str, size, - pg_encoding_to_char(encoding), "strict"); + return PyUnicode_Decode(str, size, pg_encoding_to_char(encoding), + "strict"); } static PyObject * @@ -52,7 +52,7 @@ get_encoded_string(PyObject *unicode_obj, int encoding) return PyUnicode_AsASCIIString(unicode_obj); /* encoding name should be properly translated to Python here */ return PyUnicode_AsEncodedString(unicode_obj, - pg_encoding_to_char(encoding), "strict"); + pg_encoding_to_char(encoding), "strict"); } /* Helper functions */ @@ -64,7 +64,7 @@ get_type(Oid pgtype) int t; switch (pgtype) { - /* simple types */ + /* simple types */ case INT2OID: case INT4OID: @@ -113,7 +113,7 @@ get_type(Oid pgtype) t = PYGRES_TEXT; break; - /* array types */ + /* array types */ case INT2ARRAYOID: case INT4ARRAYOID: @@ -137,8 +137,9 @@ get_type(Oid pgtype) break; case MONEYARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((decimal_point ? - PYGRES_MONEY : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((decimal_point ? PYGRES_MONEY : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BOOLARRAYOID: @@ -146,14 +147,16 @@ get_type(Oid pgtype) break; case BYTEAARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((bytea_escaped ? - PYGRES_TEXT : PYGRES_BYTEA) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((bytea_escaped ? PYGRES_TEXT : PYGRES_BYTEA) | + PYGRES_ARRAY); break; case JSONARRAYOID: case JSONBARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((jsondecode ? - PYGRES_JSON : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((jsondecode ? 
PYGRES_JSON : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BPCHARARRAYOID: @@ -178,8 +181,8 @@ get_col_types(PGresult *result, int nfields) { int *types, *t, j; - if (!(types = PyMem_Malloc(sizeof(int) * (size_t) nfields))) { - return (int*) PyErr_NoMemory(); + if (!(types = PyMem_Malloc(sizeof(int) * (size_t)nfields))) { + return (int *)PyErr_NoMemory(); } for (j = 0, t = types; j < nfields; ++j) { @@ -199,8 +202,8 @@ cast_bytea_text(char *s) size_t str_len; /* this function should not be called when bytea_escaped is set */ - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -221,16 +224,18 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) case PYGRES_BYTEA: /* this type should not be passed when bytea_escaped is set */ /* we need to add a null byte */ - tmp_str = (char *) PyMem_Malloc((size_t) size + 1); + tmp_str = (char *)PyMem_Malloc((size_t)size + 1); if (!tmp_str) { return PyErr_NoMemory(); } - memcpy(tmp_str, s, (size_t) size); - s = tmp_str; *(s + size) = '\0'; - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); + memcpy(tmp_str, s, (size_t)size); + s = tmp_str; + *(s + size) = '\0'; + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); PyMem_Free(s); - if (!tmp_str) return PyErr_NoMemory(); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + if (!tmp_str) + return PyErr_NoMemory(); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -246,7 +251,7 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) } break; - default: /* PYGRES_TEXT */ + default: /* PYGRES_TEXT */ obj = get_decoded_string(s, size, encoding); if (!obj) { /* cannot decode */ obj = PyBytes_FromStringAndSize(s, size); @@ -288,8 +293,8 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_INT: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 0, t = buf; i < n; ++i) { *t++ = *s++; @@ -300,8 +305,8 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_LONG: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 0, t = buf; i < n; ++i) { *t++ = *s++; @@ -338,14 +343,14 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) tmp_obj = PyUnicode_FromString(buf); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); - } break; case PYGRES_DECIMAL: tmp_obj = PyUnicode_FromStringAndSize(s, size); - obj = decimal ? PyObject_CallFunctionObjArgs( - decimal, tmp_obj, NULL) : PyFloat_FromString(tmp_obj); + obj = decimal + ? PyObject_CallFunctionObjArgs(decimal, tmp_obj, NULL) + : PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -404,7 +409,8 @@ cast_unsized_simple(char *s, int type) buf[j++] = '-'; } } - buf[j] = '\0'; s = buf; + buf[j] = '\0'; + s = buf; /* FALLTHROUGH */ /* no break here */ case PYGRES_DECIMAL: @@ -438,11 +444,10 @@ cast_unsized_simple(char *s, int type) } /* Quick case insensitive check if given sized string is null. 
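   The same check, mirrored in Python for clarity (a toy re-statement of
   the macro, not part of the module API):

       def str_is_null(s):
           # exactly four characters spelling NULL in any case mix
           return len(s) == 4 and s.lower() == 'null'

       assert str_is_null('NULL') and str_is_null('nuLL')
       assert not str_is_null('nil') and not str_is_null('NULLS')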
*/ -#define STR_IS_NULL(s, n) (n == 4 && \ - (s[0] == 'n' || s[0] == 'N') && \ - (s[1] == 'u' || s[1] == 'U') && \ - (s[2] == 'l' || s[2] == 'L') && \ - (s[3] == 'l' || s[3] == 'L')) +#define STR_IS_NULL(s, n) \ + (n == 4 && (s[0] == 'n' || s[0] == 'N') && \ + (s[1] == 'u' || s[1] == 'U') && (s[2] == 'l' || s[2] == 'L') && \ + (s[3] == 'l' || s[3] == 'L')) /* Cast string s with size and encoding to a Python list, using the input and output syntax for arrays. @@ -450,8 +455,8 @@ cast_unsized_simple(char *s, int type) The parameter delim specifies the delimiter for the elements, since some types do not use the default delimiter of a comma. */ static PyObject * -cast_array(char *s, Py_ssize_t size, int encoding, - int type, PyObject *cast, char delim) +cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, + char delim) { PyObject *result, *stack[MAX_ARRAY_DEPTH]; char *end = s + size, *t; @@ -459,12 +464,13 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (type) { type &= ~PYGRES_ARRAY; /* get the base type */ - if (!type) type = PYGRES_TEXT; + if (!type) + type = PYGRES_TEXT; } if (!delim) { delim = ','; } - else if (delim == '{' || delim =='}' || delim=='\\') { + else if (delim == '{' || delim == '}' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid array delimiter"); return NULL; } @@ -475,20 +481,28 @@ cast_array(char *s, Py_ssize_t size, int encoding, int valid; for (valid = 0; !valid;) { - if (s == end || *s++ != '[') break; + if (s == end || *s++ != '[') + break; while (s != end && *s == ' ') ++s; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ':') break; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s == end || *s++ != ':') + break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ']') break; + if (s == end || *s++ != ']') + break; while (s != end && *s == ' ') ++s; ++ranges; if (s != end && *s == '=') { - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); valid = 1; } } @@ -498,7 +512,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, } } for (t = s, depth = 0; t != end && (*t == '{' || *t == ' '); ++t) { - if (*t == '{') ++depth; + if (*t == '{') + ++depth; } if (!depth) { PyErr_SetString(PyExc_ValueError, @@ -516,30 +531,40 @@ cast_array(char *s, Py_ssize_t size, int encoding, } depth--; /* next level of parsing */ result = PyList_New(0); - if (!result) return NULL; - do ++s; while (s != end && *s == ' '); + if (!result) + return NULL; + do ++s; + while (s != end && *s == ' '); /* everything is set up, start parsing the array */ while (s != end) { if (*s == '}') { PyObject *subresult; - if (!level) break; /* top level array ended */ - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + if (!level) + break; /* top level array ended */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray expected but not found"); - 
Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ subresult = result; result = stack[--level]; if (PyList_Append(result, subresult)) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else if (level == depth) { /* we expect elements at this level */ @@ -551,40 +576,48 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (*s == '{') { PyErr_SetString(PyExc_ValueError, "Subarray found where not expected"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (*s == '"') { /* quoted element */ estr = ++s; while (s != end && *s != '"') { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } esize = s - estr; - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); } else { /* unquoted element */ estr = s; /* can contain blanks inside */ - while (s != end && *s != '"' && - *s != '{' && *s != '}' && *s != delim) - { + while (s != end && *s != '"' && *s != '{' && *s != '}' && + *s != delim) { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } - t = s; while (t > estr && *(t - 1) == ' ') --t; + t = s; + while (t > estr && *(t - 1) == ' ') --t; if (!(esize = t - estr)) { - s = end; break; /* error */ + s = end; + break; /* error */ } if (STR_IS_NULL(estr, esize)) /* NULL gives None */ estr = NULL; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr) { if (escaped) { char *r; @@ -592,12 +625,14 @@ cast_array(char *s, Py_ssize_t size, int encoding, /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } for (i = 0, r = estr; i < esize; ++i) { - if (*t == '\\') ++t, ++i; + if (*t == '\\') + ++t, ++i; *r++ = *t++; } esize = r - estr; @@ -609,58 +644,73 @@ cast_array(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, type); } else { /* external casting of base type */ - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); + element = encoding == pg_encoding_ascii + ? 
NULL + : get_decoded_string(estr, esize, encoding); if (!element) { /* no decoding necessary or possible */ element = PyBytes_FromStringAndSize(estr, esize); } if (element && cast) { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ } else { /* we expect arrays at this level */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray must start with a left brace"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ stack[level++] = result; - if (!(result = PyList_New(0))) return NULL; + if (!(result = PyList_New(0))) + return NULL; } } if (s == end || *s != '}') { - PyErr_SetString(PyExc_ValueError, - "Unexpected end of array"); - Py_DECREF(result); return NULL; + PyErr_SetString(PyExc_ValueError, "Unexpected end of array"); + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of array"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } return result; } @@ -672,8 +722,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, The parameter delim can specify a delimiter for the elements, although composite types always use a comma as delimiter. 
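   A hedged Python illustration, assuming the parser is exposed as the
   module-level cast_record helper as in PyGreSQL's pg module (the exact
   element types depend on the casts and encoding in effect):

       import pg

       rec = pg.cast_record('(42,"hello, world",)')
       # three fields: the quoted comma is kept inside the second one,
       # and the trailing empty field is parsed as SQL NULL
       assert len(rec) == 3 and rec[2] is None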
*/ static PyObject * -cast_record(char *s, Py_ssize_t size, int encoding, - int *type, PyObject *cast, Py_ssize_t len, char delim) +cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, + Py_ssize_t len, char delim) { PyObject *result, *ret; char *end = s + size, *t; @@ -682,7 +732,7 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (!delim) { delim = ','; } - else if (delim == '(' || delim ==')' || delim=='\\') { + else if (delim == '(' || delim == ')' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid record delimiter"); return NULL; } @@ -695,14 +745,16 @@ cast_record(char *s, Py_ssize_t size, int encoding, return NULL; } result = PyList_New(0); - if (!result) return NULL; + if (!result) + return NULL; i = 0; /* everything is set up, start parsing the record */ while (++s != end) { PyObject *element; if (*s == ')' || *s == delim) { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } else { char *estr; @@ -711,32 +763,40 @@ cast_record(char *s, Py_ssize_t size, int encoding, estr = s; quoted = *s == '"'; - if (quoted) ++s; + if (quoted) + ++s; esize = 0; while (s != end) { if (!quoted && (*s == ')' || *s == delim)) break; if (*s == '"') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; if (!(quoted && *s == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; } ++s, ++esize; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr + esize != s) { char *r; escaped = 1; /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } quoted = 0; r = estr; @@ -744,10 +804,12 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (*t == '"') { ++t; if (!(quoted && *t == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } - if (*t == '\\') ++t; + if (*t == '\\') + ++t; *r++ = *t++; } } @@ -755,16 +817,17 @@ cast_record(char *s, Py_ssize_t size, int encoding, int etype = type[i]; if (etype & PYGRES_ARRAY) - element = cast_array( - estr, esize, encoding, etype, NULL, 0); + element = + cast_array(estr, esize, encoding, etype, NULL, 0); else if (etype & PYGRES_TEXT) element = cast_sized_text(estr, esize, encoding, etype); else element = cast_sized_simple(estr, esize, etype); } else { /* external casting of base type */ - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); + element = encoding == pg_encoding_ascii + ? 
NULL + : get_decoded_string(estr, esize, encoding); if (!element) { /* no decoding necessary or possible */ element = PyBytes_FromStringAndSize(estr, esize); } @@ -781,46 +844,58 @@ cast_record(char *s, Py_ssize_t size, int encoding, } } else { - Py_DECREF(element); element = NULL; + Py_DECREF(element); + element = NULL; } } else { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); - if (len) ++i; - if (*s != delim) break; /* no next record */ + if (len) + ++i; + if (*s != delim) + break; /* no next record */ if (len && i >= len) { PyErr_SetString(PyExc_ValueError, "Too many columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (s == end || *s != ')') { PyErr_SetString(PyExc_ValueError, "Unexpected end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (len && i < len) { PyErr_SetString(PyExc_ValueError, "Too few columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } ret = PyList_AsTuple(result); @@ -846,94 +921,116 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) int quoted; while (s != end && *s == ' ') ++s; - if (s == end) break; + if (s == end) + break; quoted = *s == '"'; if (quoted) { key = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { key = s; while (s != end) { - if (*s == '=' || *s == ' ') break; + if (*s == '=' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == key) { PyErr_SetString(PyExc_ValueError, "Missing key"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } size = s - key - key_esc; if (key_esc) { char *r = key, *t; - key = (char *) PyMem_Malloc((size_t) size); + key = (char *)PyMem_Malloc((size_t)size); if (!key) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } t = key; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } key_obj = cast_sized_text(key, size, encoding, PYGRES_TEXT); - if (key_esc) PyMem_Free(key); + if (key_esc) + PyMem_Free(key); if (!key_obj) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s == end || *s++ != '=' || s == end || *s++ != '>') { PyErr_SetString(PyExc_ValueError, "Invalid characters after key"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; quoted = *s == '"'; if (quoted) { 
val = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { val = s; while (s != end) { - if (*s == ',' || *s == ' ') break; + if (*s == ',' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == val) { PyErr_SetString(PyExc_ValueError, "Missing value"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } if (STR_IS_NULL(val, s - val)) val = NULL; @@ -942,46 +1039,59 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) size = s - val - val_esc; if (val_esc) { char *r = val, *t; - val = (char *) PyMem_Malloc((size_t) size); + val = (char *)PyMem_Malloc((size_t)size); if (!val) { - Py_DECREF(key_obj); Py_DECREF(result); + Py_DECREF(key_obj); + Py_DECREF(result); return PyErr_NoMemory(); } t = val; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } val_obj = cast_sized_text(val, size, encoding, PYGRES_TEXT); - if (val_esc) PyMem_Free(val); + if (val_esc) + PyMem_Free(val); if (!val_obj) { - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); val_obj = Py_None; + Py_INCREF(Py_None); + val_obj = Py_None; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s != end) { if (*s++ != ',') { PyErr_SetString(PyExc_ValueError, "Invalid characters after val"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; if (s == end) { PyErr_SetString(PyExc_ValueError, "Missing entry"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } } PyDict_SetItem(result, key_obj, val_obj); - Py_DECREF(key_obj); Py_DECREF(val_obj); + Py_DECREF(key_obj); + Py_DECREF(val_obj); } return result; } @@ -1054,15 +1164,15 @@ get_error_type(const char *sqlstate) /* Set database error message and sqlstate attribute. */ static void -set_error_msg_and_state(PyObject *type, - const char *msg, int encoding, const char *sqlstate) +set_error_msg_and_state(PyObject *type, const char *msg, int encoding, + const char *sqlstate) { PyObject *err_obj, *msg_obj, *sql_obj = NULL; if (encoding == -1) /* unknown */ msg_obj = PyUnicode_DecodeLocale(msg, NULL); else - msg_obj = get_decoded_string(msg, (Py_ssize_t) strlen(msg), encoding); + msg_obj = get_decoded_string(msg, (Py_ssize_t)strlen(msg), encoding); if (!msg_obj) /* cannot decode */ msg_obj = PyBytes_FromString(msg); @@ -1070,7 +1180,8 @@ set_error_msg_and_state(PyObject *type, sql_obj = PyUnicode_FromStringAndSize(sqlstate, 5); } else { - Py_INCREF(Py_None); sql_obj = Py_None; + Py_INCREF(Py_None); + sql_obj = Py_None; } err_obj = PyObject_CallFunctionObjArgs(type, msg_obj, NULL); @@ -1095,7 +1206,7 @@ set_error_msg(PyObject *type, const char *msg) /* Set database error from connection and/or result. 
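   The practical effect on the Python side: errors raised during query
   execution carry the mapped exception type plus a sqlstate attribute
   (hypothetical database and table below):

       import pg

       cnx = pg.connect(dbname='test')
       try:
           cnx.query('select * from no_such_table')
       except pg.ProgrammingError as error:
           print(error.sqlstate)            # '42P01', undefined_table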
*/ static void -set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) +set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result) { char *sqlstate = NULL; int encoding = pg_encoding_ascii; @@ -1109,7 +1220,8 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) } if (result) { sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); - if (sqlstate) type = get_error_type(sqlstate); + if (sqlstate) + type = get_error_type(sqlstate); } set_error_msg_and_state(type, msg, encoding, sqlstate); @@ -1117,9 +1229,10 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) /* Get SSL attributes and values as a dictionary. */ static PyObject * -get_ssl_attributes(PGconn *cnx) { +get_ssl_attributes(PGconn *cnx) +{ PyObject *attr_dict = NULL; - const char * const *s; + const char *const *s; if (!(attr_dict = PyDict_New())) { return NULL; @@ -1129,7 +1242,7 @@ get_ssl_attributes(PGconn *cnx) { const char *val = PQsslAttribute(cnx, *s); if (val) { - PyObject * val_obj = PyUnicode_FromString(val); + PyObject *val_obj = PyUnicode_FromString(val); PyDict_SetItemString(attr_dict, *s, val_obj); Py_DECREF(val_obj); @@ -1153,10 +1266,10 @@ format_result(const PGresult *res) const int n = PQnfields(res); if (n > 0) { - char * const aligns = (char *) PyMem_Malloc( - (unsigned int) n * sizeof(char)); - size_t * const sizes = (size_t *) PyMem_Malloc( - (unsigned int) n * sizeof(size_t)); + char *const aligns = + (char *)PyMem_Malloc((unsigned int)n * sizeof(char)); + size_t *const sizes = + (size_t *)PyMem_Malloc((unsigned int)n * sizeof(size_t)); if (aligns && sizes) { const int m = PQntuples(res); @@ -1166,7 +1279,7 @@ format_result(const PGresult *res) /* calculate sizes and alignments */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const int format = PQfformat(res, j); sizes[j] = s ? strlen(s) : 0; @@ -1202,9 +1315,9 @@ format_result(const PGresult *res) if (aligns[j]) { const int k = PQgetlength(res, i, j); - if (sizes[j] < (size_t) k) + if (sizes[j] < (size_t)k) /* value must fit */ - sizes[j] = (size_t) k; + sizes[j] = (size_t)k; } } } @@ -1212,23 +1325,23 @@ format_result(const PGresult *res) /* size of one row */ for (j = 0; j < n; ++j) size += sizes[j] + 1; /* times number of rows incl. heading */ - size *= (size_t) m + 2; + size *= (size_t)m + 2; /* plus size of footer */ size += 40; /* is the buffer size that needs to be allocated */ - buffer = (char *) PyMem_Malloc(size); + buffer = (char *)PyMem_Malloc(size); if (buffer) { char *p = buffer; PyObject *result; /* create the header */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const size_t k = sizes[j]; - const size_t h = (k - (size_t) strlen(s)) / 2; + const size_t h = (k - (size_t)strlen(s)) / 2; - sprintf(p, "%*s", (int) h, ""); - sprintf(p + h, "%-*s", (int) (k - h), s); + sprintf(p, "%*s", (int)h, ""); + sprintf(p + h, "%-*s", (int)(k - h), s); p += k; if (j + 1 < n) *p++ = '|'; @@ -1237,8 +1350,7 @@ format_result(const PGresult *res) for (j = 0; j < n; ++j) { size_t k = sizes[j]; - while (k--) - *p++ = '-'; + while (k--) *p++ = '-'; if (j + 1 < n) *p++ = '+'; } @@ -1250,11 +1362,11 @@ format_result(const PGresult *res) const size_t k = sizes[j]; if (align) { - sprintf(p, align == 'r' ? "%*s" : "%-*s", (int) k, + sprintf(p, align == 'r' ? 
"%*s" : "%-*s", (int)k, PQgetvalue(res, i, j)); } else { - sprintf(p, "%-*s", (int) k, + sprintf(p, "%-*s", (int)k, PQgetisnull(res, i, j) ? "" : ""); } p += k; @@ -1264,7 +1376,8 @@ format_result(const PGresult *res) *p++ = '\n'; } /* free memory */ - PyMem_Free(aligns); PyMem_Free(sizes); + PyMem_Free(aligns); + PyMem_Free(sizes); /* create the footer */ sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); /* return the result */ @@ -1273,11 +1386,15 @@ format_result(const PGresult *res) return result; } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else @@ -1288,28 +1405,31 @@ format_result(const PGresult *res) static const char * date_style_to_format(const char *s) { - static const char *formats[] = - { - "%Y-%m-%d", /* 0 = ISO */ - "%m-%d-%Y", /* 1 = Postgres, MDY */ - "%d-%m-%Y", /* 2 = Postgres, DMY */ - "%m/%d/%Y", /* 3 = SQL, MDY */ - "%d/%m/%Y", /* 4 = SQL, DMY */ - "%d.%m.%Y" /* 5 = German */ + static const char *formats[] = { + "%Y-%m-%d", /* 0 = ISO */ + "%m-%d-%Y", /* 1 = Postgres, MDY */ + "%d-%m-%Y", /* 2 = Postgres, DMY */ + "%m/%d/%Y", /* 3 = SQL, MDY */ + "%d/%m/%Y", /* 4 = SQL, DMY */ + "%d.%m.%Y" /* 5 = German */ }; switch (s ? *s : 'I') { case 'P': /* Postgres */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 2 : 1]; case 'S': /* SQL */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 4 : 3]; case 'G': /* German */ return formats[5]; - default: /* ISO */ + default: /* ISO */ return formats[0]; /* ISO is the default */ } } @@ -1318,14 +1438,13 @@ date_style_to_format(const char *s) static const char * date_format_to_style(const char *s) { - static const char *datestyle[] = - { - "ISO, YMD", /* 0 = %Y-%m-%d */ - "Postgres, MDY", /* 1 = %m-%d-%Y */ - "Postgres, DMY", /* 2 = %d-%m-%Y */ - "SQL, MDY", /* 3 = %m/%d/%Y */ - "SQL, DMY", /* 4 = %d/%m/%Y */ - "German, DMY" /* 5 = %d.%m.%Y */ + static const char *datestyle[] = { + "ISO, YMD", /* 0 = %Y-%m-%d */ + "Postgres, MDY", /* 1 = %m-%d-%Y */ + "Postgres, DMY", /* 2 = %d-%m-%Y */ + "SQL, MDY", /* 3 = %m/%d/%Y */ + "SQL, DMY", /* 4 = %d/%m/%Y */ + "German, DMY" /* 5 = %d.%m.%Y */ }; switch (s ? s[1] : 'Y') { @@ -1355,7 +1474,7 @@ static void notice_receiver(void *arg, const PGresult *res) { PyGILState_STATE gstate = PyGILState_Ensure(); - connObject *self = (connObject*) arg; + connObject *self = (connObject *)arg; PyObject *func = self->notice_receiver; if (func) { @@ -1367,7 +1486,7 @@ notice_receiver(void *arg, const PGresult *res) } else { Py_INCREF(Py_None); - notice = (noticeObject *)(void *) Py_None; + notice = (noticeObject *)(void *)Py_None; } ret = PyObject_CallFunction(func, "(O)", notice); Py_XDECREF(ret); diff --git a/pglarge.c b/pglarge.c index 863e2ec9..77455361 100644 --- a/pglarge.c +++ b/pglarge.c @@ -28,9 +28,10 @@ static PyObject * large_str(largeObject *self) { char str[80]; - sprintf(str, self->lo_fd >= 0 ? - "Opened large object, oid %ld" : - "Closed large object, oid %ld", (long) self->lo_oid); + sprintf(str, + self->lo_fd >= 0 ? 
"Opened large object, oid %ld" + : "Closed large object, oid %ld", + (long)self->lo_oid); return PyUnicode_FromString(str); } @@ -75,7 +76,7 @@ large_getattr(largeObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (_check_lo_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } PyErr_Clear(); Py_INCREF(Py_None); @@ -85,7 +86,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyLong_FromLong((long) self->lo_oid); + return PyLong_FromLong((long)self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; @@ -96,7 +97,7 @@ large_getattr(largeObject *self, PyObject *nameobj) return PyUnicode_FromString(PQerrorMessage(self->pgcnx->cnx)); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Get the list of large object attributes. */ @@ -105,17 +106,16 @@ large_dir(largeObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sss]", "oid", "pgcnx", "error"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sss]", "oid", "pgcnx", "error"); return attrs; } /* Open large object. */ static char large_open__doc__[] = -"open(mode) -- open access to large object with specified mode\n\n" -"The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; + "open(mode) -- open access to large object with specified mode\n\n" + "The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; static PyObject * large_open(largeObject *self, PyObject *args) @@ -148,7 +148,7 @@ large_open(largeObject *self, PyObject *args) /* Close large object. */ static char large_close__doc__[] = -"close() -- close access to large object data"; + "close() -- close access to large object data"; static PyObject * large_close(largeObject *self, PyObject *noargs) @@ -172,8 +172,8 @@ large_close(largeObject *self, PyObject *noargs) /* Read from large object. */ static char large_read__doc__[] = -"read(size) -- read from large object to sized string\n\n" -"Object must be opened in read mode before calling this method.\n"; + "read(size) -- read from large object to sized string\n\n" + "Object must be opened in read mode before calling this method.\n"; static PyObject * large_read(largeObject *self, PyObject *args) @@ -200,11 +200,11 @@ large_read(largeObject *self, PyObject *args) } /* allocate buffer and runs read */ - buffer = PyBytes_FromStringAndSize((char *) NULL, size); + buffer = PyBytes_FromStringAndSize((char *)NULL, size); if ((size = lo_read(self->pgcnx->cnx, self->lo_fd, - PyBytes_AS_STRING((PyBytesObject *) (buffer)), (size_t) size)) == -1) - { + PyBytes_AS_STRING((PyBytesObject *)(buffer)), + (size_t)size)) == -1) { PyErr_SetString(PyExc_IOError, "Error while reading"); Py_XDECREF(buffer); return NULL; @@ -217,8 +217,8 @@ large_read(largeObject *self, PyObject *args) /* Write to large object. 
*/ static char large_write__doc__[] = -"write(string) -- write sized string to large object\n\n" -"Object must be opened in read mode before calling this method.\n"; + "write(string) -- write sized string to large object\n\n" + "Object must be opened in write mode before calling this method.\n"; static PyObject * large_write(largeObject *self, PyObject *args) @@ -241,8 +241,7 @@ large_write(largeObject *self, PyObject *args) /* sends query */ if ((size = lo_write(self->pgcnx->cnx, self->lo_fd, buffer, - (size_t) bufsize)) != bufsize) - { + (size_t)bufsize)) != bufsize) { PyErr_SetString(PyExc_IOError, "Buffer truncated during write"); return NULL; } @@ -254,9 +253,9 @@ large_write(largeObject *self, PyObject *args) /* Go to position in large object. */ static char large_seek__doc__[] = -"seek(offset, whence) -- move to specified position\n\n" -"Object must be opened before calling this method. The whence option\n" -"can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; + "seek(offset, whence) -- move to specified position\n\n" + "Object must be opened before calling this method. The whence option\n" + "can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; static PyObject * large_seek(largeObject *self, PyObject *args) @@ -277,9 +276,8 @@ large_seek(largeObject *self, PyObject *args) } /* sends query */ - if ((ret = lo_lseek( - self->pgcnx->cnx, self->lo_fd, offset, whence)) == -1) - { + if ((ret = lo_lseek(self->pgcnx->cnx, self->lo_fd, offset, whence)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving cursor"); return NULL; } @@ -290,8 +288,8 @@ large_seek(largeObject *self, PyObject *args) /* Get large object size. */ static char large_size__doc__[] = -"size() -- return large object size\n\n" -"The object must be opened before calling this method.\n"; + "size() -- return large object size\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_size(largeObject *self, PyObject *noargs) @@ -316,9 +314,8 @@ large_size(largeObject *self, PyObject *noargs) } /* move back to start position */ - if ((start = lo_lseek( - self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == -1) - { + if ((start = lo_lseek(self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving back to first position"); return NULL; @@ -330,8 +327,8 @@ large_size(largeObject *self, PyObject *noargs) /* Get large object cursor position. */ static char large_tell__doc__[] = -"tell() -- give current position in large object\n\n" -"The object must be opened before calling this method.\n"; + "tell() -- give current position in large object\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_tell(largeObject *self, PyObject *noargs) @@ -355,8 +352,8 @@ large_tell(largeObject *self, PyObject *noargs) /* Export large object as unix file. */ static char large_export__doc__[] = -"export(filename) -- export large object data to specified file\n\n" -"The object must be closed when calling this method.\n"; + "export(filename) -- export large object data to specified file\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_export(largeObject *self, PyObject *args) @@ -387,8 +384,8 @@ large_export(largeObject *self, PyObject *args) /* Delete a large object. 
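Taken together, the docstrings above describe the classic large-object workflow. A hedged sketch in Python, assuming a reachable test database and the connection's locreate() method from pgconn.c (not part of this hunk); the connection parameters are placeholders:

    import pg

    con = pg.connect(dbname='test')
    lo = con.locreate(pg.INV_READ | pg.INV_WRITE)  # create a new large object
    lo.open(pg.INV_WRITE)
    lo.write(b'hello')
    lo.close()
    lo.open(pg.INV_READ)
    assert lo.size() == 5          # implemented above via lo_lseek to SEEK_END
    assert lo.read(5) == b'hello'
    lo.close()
    lo.unlink()                    # the object must be closed before unlinking
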
*/ static char large_unlink__doc__[] = -"unlink() -- destroy large object\n\n" -"The object must be closed when calling this method.\n"; + "unlink() -- destroy large object\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_unlink(largeObject *self, PyObject *noargs) @@ -411,51 +408,49 @@ large_unlink(largeObject *self, PyObject *noargs) /* Large object methods */ static struct PyMethodDef large_methods[] = { - {"__dir__", (PyCFunction) large_dir, METH_NOARGS, NULL}, - {"open", (PyCFunction) large_open, METH_VARARGS, large_open__doc__}, - {"close", (PyCFunction) large_close, METH_NOARGS, large_close__doc__}, - {"read", (PyCFunction) large_read, METH_VARARGS, large_read__doc__}, - {"write", (PyCFunction) large_write, METH_VARARGS, large_write__doc__}, - {"seek", (PyCFunction) large_seek, METH_VARARGS, large_seek__doc__}, - {"size", (PyCFunction) large_size, METH_NOARGS, large_size__doc__}, - {"tell", (PyCFunction) large_tell, METH_NOARGS, large_tell__doc__}, - {"export",(PyCFunction) large_export, METH_VARARGS, large_export__doc__}, - {"unlink",(PyCFunction) large_unlink, METH_NOARGS, large_unlink__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)large_dir, METH_NOARGS, NULL}, + {"open", (PyCFunction)large_open, METH_VARARGS, large_open__doc__}, + {"close", (PyCFunction)large_close, METH_NOARGS, large_close__doc__}, + {"read", (PyCFunction)large_read, METH_VARARGS, large_read__doc__}, + {"write", (PyCFunction)large_write, METH_VARARGS, large_write__doc__}, + {"seek", (PyCFunction)large_seek, METH_VARARGS, large_seek__doc__}, + {"size", (PyCFunction)large_size, METH_NOARGS, large_size__doc__}, + {"tell", (PyCFunction)large_tell, METH_NOARGS, large_tell__doc__}, + {"export", (PyCFunction)large_export, METH_VARARGS, large_export__doc__}, + {"unlink", (PyCFunction)large_unlink, METH_NOARGS, large_unlink__doc__}, + {NULL, NULL}}; static char large__doc__[] = "PostgreSQL large object"; /* Large object type definition */ static PyTypeObject largeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.LargeObject", /* tp_name */ - sizeof(largeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.LargeObject", /* tp_name */ + sizeof(largeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) large_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) large_str, /* tp_str */ - (getattrofunc) large_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - large__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - large_methods, /* tp_methods */ + (destructor)large_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)large_str, /* tp_str */ + (getattrofunc)large_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + large__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 
large_methods, /* tp_methods */ }; diff --git a/pgmodule.c b/pgmodule.c index f1335263..628de9ec 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -12,7 +12,6 @@ #define PY_SSIZE_T_CLEAN #include - #include #include @@ -20,9 +19,9 @@ #include "pgtypes.h" static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, - *InternalError, *OperationalError, *ProgrammingError, - *IntegrityError, *DataError, *NotSupportedError, - *InvalidResultError, *NoResultError, *MultipleResultsError; + *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, + *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, + *MultipleResultsError; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -36,23 +35,23 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define PG_ARRAYSIZE 1 /* Flags for object validity checks */ -#define CHECK_OPEN 1 -#define CHECK_CLOSE 2 -#define CHECK_CNX 4 +#define CHECK_OPEN 1 +#define CHECK_CLOSE 2 +#define CHECK_CNX 4 #define CHECK_RESULT 8 -#define CHECK_DQL 16 +#define CHECK_DQL 16 /* Query result types */ #define RESULT_EMPTY 1 -#define RESULT_DML 2 -#define RESULT_DDL 3 -#define RESULT_DQL 4 +#define RESULT_DML 2 +#define RESULT_DDL 3 +#define RESULT_DQL 4 /* Flags for move methods */ #define QUERY_MOVEFIRST 1 -#define QUERY_MOVELAST 2 -#define QUERY_MOVENEXT 3 -#define QUERY_MOVEPREV 4 +#define QUERY_MOVELAST 2 +#define QUERY_MOVENEXT 3 +#define QUERY_MOVEPREV 4 #define MAX_BUFFER_SIZE 65536 /* maximum transaction size */ #define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ @@ -67,16 +66,17 @@ static PyObject *pg_default_user; /* default username */ static PyObject *pg_default_passwd; /* default password */ static PyObject *decimal = NULL, /* decimal type */ - *dictiter = NULL, /* function for getting dict results */ - *namediter = NULL, /* function for getting named results */ - *namednext = NULL, /* function for getting one named result */ + *dictiter = NULL, /* function for getting dict results */ + *namediter = NULL, /* function for getting named results */ + *namednext = NULL, /* function for getting one named result */ *scalariter = NULL, /* function for getting scalar results */ - *jsondecode = NULL; /* function for decoding json strings */ + *jsondecode = + NULL; /* function for decoding json strings */ static const char *date_format = NULL; /* date format that is always assumed */ -static char decimal_point = '.'; /* decimal point used in money values */ -static int bool_as_text = 0; /* whether bool shall be returned as text */ -static int array_as_text = 0; /* whether arrays shall be returned as text */ -static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ +static char decimal_point = '.'; /* decimal point used in money values */ +static int bool_as_text = 0; /* whether bool shall be returned as text */ +static int array_as_text = 0; /* whether arrays shall be returned as text */ +static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ static int pg_encoding_utf8 = 0; static int pg_encoding_latin1 = 0; @@ -106,65 +106,56 @@ OBJECTS static PyTypeObject connType, sourceType, queryType, noticeType, largeType; /* Forward static declarations */ -static void notice_receiver(void *, const PGresult *); +static void +notice_receiver(void *, const PGresult *); /* Object declarations */ -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ - PGconn *cnx; /* Postgres connection handle */ - const char *date_format; /* date format derived from datestyle */ - 
PyObject *cast_hook; /* external typecast method */ - PyObject *notice_receiver; /* current notice receiver */ -} connObject; +typedef struct { + PyObject_HEAD int valid; /* validity flag */ + PGconn *cnx; /* Postgres connection handle */ + const char *date_format; /* date format derived from datestyle */ + PyObject *cast_hook; /* external typecast method */ + PyObject *notice_receiver; /* current notice receiver */ +} connObject; #define is_connObject(v) (PyType(v) == &connType) -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ +typedef struct { + PyObject_HEAD int valid; /* validity flag */ connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int encoding; /* client encoding */ - int result_type; /* result type (DDL/DML/DQL) */ - long arraysize; /* array size for fetch method */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ -} sourceObject; + PGresult *result; /* result content */ + int encoding; /* client encoding */ + int result_type; /* result type (DDL/DML/DQL) */ + long arraysize; /* array size for fetch method */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ +} sourceObject; #define is_sourceObject(v) (PyType(v) == &sourceType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult const *res; /* an error or warning */ -} noticeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult const *res; /* an error or warning */ +} noticeObject; #define is_noticeObject(v) (PyType(v) == &noticeType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int async; /* flag for asynchronous queries */ - int encoding; /* client encoding */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ - int *col_types; /* PyGreSQL column types */ -} queryObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult *result; /* result content */ + int async; /* flag for asynchronous queries */ + int encoding; /* client encoding */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ + int *col_types; /* PyGreSQL column types */ +} queryObject; #define is_queryObject(v) (PyType(v) == &queryType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - Oid lo_oid; /* large object oid */ - int lo_fd; /* large object fd */ -} largeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + Oid lo_oid; /* large object oid */ + int lo_fd; /* large object fd */ +} largeObject; #define is_largeObject(v) (PyType(v) == &largeType) /* Internal functions */ @@ -189,22 +180,22 @@ typedef struct /* Connect to a database. 
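These structs back the Python-visible Connection, Source, Notice, Query and LargeObject types. The pg_connect() function defined next parses the keywords listed in its docstring; a minimal sketch with placeholder credentials:

    import pg

    con = pg.connect(dbname='test', host='localhost', port=5432,
                     user='postgres', passwd='secret')
    con.close()
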
*/ static char pg_connect__doc__[] = -"connect(dbname, host, port, opt, user, passwd, wait) -- connect to a PostgreSQL database\n\n" -"The connection uses the specified parameters (optional, keywords aware).\n"; + "connect(dbname, host, port, opt, user, passwd, wait) -- connect to a " + "PostgreSQL database\n\n" + "The connection uses the specified parameters (optional, keywords " + "aware).\n"; static PyObject * pg_connect(PyObject *self, PyObject *args, PyObject *dict) { - static const char *kwlist[] = - { - "dbname", "host", "port", "opt", "user", "passwd", "nowait", NULL - }; + static const char *kwlist[] = {"dbname", "host", "port", "opt", + "user", "passwd", "nowait", NULL}; char *pghost, *pgopt, *pgdbname, *pguser, *pgpasswd; int pgport = -1, nowait = 0, nkw = 0; char port_buffer[20]; const char *keywords[sizeof(kwlist) / sizeof(*kwlist) + 1], - *values[sizeof(kwlist) / sizeof(*kwlist) + 1]; + *values[sizeof(kwlist) / sizeof(*kwlist) + 1]; connObject *conn_obj; pghost = pgopt = pgdbname = pguser = pgpasswd = NULL; @@ -215,10 +206,9 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) * don't declare kwlist as const char *kwlist[] then it complains when * I try to assign all those constant strings to it. */ - if (!PyArg_ParseTupleAndKeywords( - args, dict, "|zzizzzi", (char**)kwlist, - &pgdbname, &pghost, &pgport, &pgopt, &pguser, &pgpasswd, &nowait)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "|zzizzzi", (char **)kwlist, + &pgdbname, &pghost, &pgport, &pgopt, + &pguser, &pgpasswd, &nowait)) { return NULL; } @@ -227,7 +217,7 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) pghost = PyBytes_AsString(pg_default_host); if ((pgport == -1) && (pg_default_port != Py_None)) - pgport = (int) PyLong_AsLong(pg_default_port); + pgport = (int)PyLong_AsLong(pg_default_port); if ((!pgopt) && (pg_default_opt != Py_None)) pgopt = PyBytes_AsString(pg_default_opt); @@ -252,33 +242,27 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) conn_obj->cast_hook = NULL; conn_obj->notice_receiver = NULL; - if (pghost) - { + if (pghost) { keywords[nkw] = "host"; values[nkw++] = pghost; } - if (pgopt) - { + if (pgopt) { keywords[nkw] = "options"; values[nkw++] = pgopt; } - if (pgdbname) - { + if (pgdbname) { keywords[nkw] = "dbname"; values[nkw++] = pgdbname; } - if (pguser) - { + if (pguser) { keywords[nkw] = "user"; values[nkw++] = pguser; } - if (pgpasswd) - { + if (pgpasswd) { keywords[nkw] = "password"; values[nkw++] = pgpasswd; } - if (pgport != -1) - { + if (pgport != -1) { memset(port_buffer, 0, sizeof(port_buffer)); sprintf(port_buffer, "%d", pgport); @@ -288,8 +272,8 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) keywords[nkw] = values[nkw] = NULL; Py_BEGIN_ALLOW_THREADS - conn_obj->cnx = nowait ? PQconnectStartParams(keywords, values, 1) : - PQconnectdbParams(keywords, values, 1); + conn_obj->cnx = nowait ? 
PQconnectStartParams(keywords, values, 1) + : PQconnectdbParams(keywords, values, 1); Py_END_ALLOW_THREADS if (PQstatus(conn_obj->cnx) == CONNECTION_BAD) { @@ -298,32 +282,33 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return NULL; } - return (PyObject *) conn_obj; + return (PyObject *)conn_obj; } /* Get version of libpq that is being used */ static char pg_get_pqlib_version__doc__[] = -"get_pqlib_version() -- get the version of libpq that is being used"; + "get_pqlib_version() -- get the version of libpq that is being used"; static PyObject * -pg_get_pqlib_version(PyObject *self, PyObject *noargs) { +pg_get_pqlib_version(PyObject *self, PyObject *noargs) +{ return PyLong_FromLong(PQlibVersion()); } /* Escape string */ static char pg_escape_string__doc__[] = -"escape_string(string) -- escape a string for use within SQL"; + "escape_string(string) -- escape a string for use within SQL"; static PyObject * pg_escape_string(PyObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -331,7 +316,8 @@ pg_escape_string(PyObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -340,38 +326,39 @@ pg_escape_string(PyObject *self, PyObject *string) return NULL; } - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t ) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; } - to = (char *) PyMem_Malloc(to_length); - to_length = (size_t) PQescapeString(to, from, (size_t) from_length); + to = (char *)PyMem_Malloc(to_length); + to_length = (size_t)PQescapeString(to, from, (size_t)from_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); PyMem_Free(to); return to_obj; } /* Escape bytea */ static char pg_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; static PyObject * pg_escape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument 
as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); @@ -379,7 +366,8 @@ pg_escape_bytea(PyObject *self, PyObject *data) else if (PyUnicode_Check(data)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -388,15 +376,15 @@ pg_escape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQescapeBytea( - (unsigned char*) from, (size_t) from_length, &to_length); + to = (char *)PQescapeBytea((unsigned char *)from, (size_t)from_length, + &to_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); if (to) PQfreemem(to); return to_obj; @@ -404,24 +392,25 @@ pg_escape_bytea(PyObject *self, PyObject *data) /* Unescape bytea */ static char pg_unescape_bytea__doc__[] = -"unescape_bytea(string) -- unescape bytea data retrieved as text"; + "unescape_bytea(string) -- unescape bytea data retrieved as text"; static PyObject * pg_unescape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); } else if (PyUnicode_Check(data)) { tmp_obj = get_encoded_string(data, pg_encoding_ascii); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -431,13 +420,14 @@ pg_unescape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQunescapeBytea((unsigned char*) from, &to_length); + to = (char *)PQunescapeBytea((unsigned char *)from, &to_length); Py_XDECREF(tmp_obj); - if (!to) return PyErr_NoMemory(); + if (!to) + return PyErr_NoMemory(); - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); PQfreemem(to); return to_obj; @@ -445,7 +435,7 @@ pg_unescape_bytea(PyObject *self, PyObject *data) /* Set fixed datestyle. 
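A short sketch of the module-level escaping helpers reformatted above (they work without a connection and assume ASCII encoding; the bytea example uses the octal escape format):

    import pg

    assert pg.escape_string("it's") == "it''s"           # quotes are doubled
    assert pg.unescape_bytea(b'\\000abc') == b'\x00abc'  # octal escapes decoded
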
*/ static char pg_set_datestyle__doc__[] = -"set_datestyle(style) -- set which style is assumed"; + "set_datestyle(style) -- set which style is assumed"; static PyObject * pg_set_datestyle(PyObject *self, PyObject *args) @@ -462,12 +452,13 @@ pg_set_datestyle(PyObject *self, PyObject *args) date_format = datestyle ? date_style_to_format(datestyle) : NULL; - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } /* Get fixed datestyle. */ static char pg_get_datestyle__doc__[] = -"get_datestyle() -- get which date style is assumed"; + "get_datestyle() -- get which date style is assumed"; static PyObject * pg_get_datestyle(PyObject *self, PyObject *noargs) @@ -476,13 +467,14 @@ pg_get_datestyle(PyObject *self, PyObject *noargs) return PyUnicode_FromString(date_format_to_style(date_format)); } else { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } } /* Get decimal point. */ static char pg_get_decimal_point__doc__[] = -"get_decimal_point() -- get decimal point to be used for money values"; + "get_decimal_point() -- get decimal point to be used for money values"; static PyObject * pg_get_decimal_point(PyObject *self, PyObject *noargs) @@ -491,11 +483,13 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) char s[2]; if (decimal_point) { - s[0] = decimal_point; s[1] = '\0'; + s[0] = decimal_point; + s[1] = '\0'; ret = PyUnicode_FromString(s); } else { - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } return ret; @@ -503,7 +497,7 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) /* Set decimal point. */ static char pg_set_decimal_point__doc__[] = -"set_decimal_point(char) -- set decimal point to be used for money values"; + "set_decimal_point(char) -- set decimal point to be used for money values"; static PyObject * pg_set_decimal_point(PyObject *self, PyObject *args) @@ -515,13 +509,14 @@ pg_set_decimal_point(PyObject *self, PyObject *args) if (PyArg_ParseTuple(args, "z", &s)) { if (!s) s = "\0"; - else if (*s && (*(s+1) || !strchr(".,;: '*/_`|", *s))) + else if (*s && (*(s + 1) || !strchr(".,;: '*/_`|", *s))) s = NULL; } if (s) { decimal_point = *s; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -533,7 +528,7 @@ pg_set_decimal_point(PyObject *self, PyObject *args) /* Get decimal type. */ static char pg_get_decimal__doc__[] = -"get_decimal() -- get the decimal type to be used for numeric values"; + "get_decimal() -- get the decimal type to be used for numeric values"; static PyObject * pg_get_decimal(PyObject *self, PyObject *noargs) @@ -548,7 +543,7 @@ pg_get_decimal(PyObject *self, PyObject *noargs) /* Set decimal type. 
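The datestyle and decimal-point switches above are process-global; a sketch of round-tripping them (the 'German, DMY' result follows from the style tables earlier in this patch):

    import pg

    pg.set_decimal_point(',')             # parse money values like '3,50'
    assert pg.get_decimal_point() == ','
    pg.set_decimal_point('.')             # restore the default

    pg.set_datestyle('German')            # dates are then assumed as %d.%m.%Y
    assert pg.get_datestyle() == 'German, DMY'
    pg.set_datestyle(None)                # fall back to the session datestyle
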
*/ static char pg_set_decimal__doc__[] = -"set_decimal(cls) -- set a decimal type to be used for numeric values"; + "set_decimal(cls) -- set a decimal type to be used for numeric values"; static PyObject * pg_set_decimal(PyObject *self, PyObject *cls) @@ -556,12 +551,17 @@ pg_set_decimal(PyObject *self, PyObject *cls) PyObject *ret = NULL; if (cls == Py_None) { - Py_XDECREF(decimal); decimal = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(decimal); + decimal = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(cls)) { - Py_XINCREF(cls); Py_XDECREF(decimal); decimal = cls; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(cls); + Py_XDECREF(decimal); + decimal = cls; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -574,7 +574,7 @@ pg_set_decimal(PyObject *self, PyObject *cls) /* Get usage of bool values. */ static char pg_get_bool__doc__[] = -"get_bool() -- check whether boolean values are converted to bool"; + "get_bool() -- check whether boolean values are converted to bool"; static PyObject * pg_get_bool(PyObject *self, PyObject *noargs) @@ -589,7 +589,7 @@ pg_get_bool(PyObject *self, PyObject *noargs) /* Set usage of bool values. */ static char pg_set_bool__doc__[] = -"set_bool(on) -- set whether boolean values should be converted to bool"; + "set_bool(on) -- set whether boolean values should be converted to bool"; static PyObject * pg_set_bool(PyObject *self, PyObject *args) @@ -600,7 +600,8 @@ pg_set_bool(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bool_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -613,7 +614,7 @@ pg_set_bool(PyObject *self, PyObject *args) /* Get conversion of arrays to lists. */ static char pg_get_array__doc__[] = -"get_array() -- check whether arrays are converted as lists"; + "get_array() -- check whether arrays are converted as lists"; static PyObject * pg_get_array(PyObject *self, PyObject *noargs) @@ -628,18 +629,19 @@ pg_get_array(PyObject *self, PyObject *noargs) /* Set conversion of arrays to lists. */ static char pg_set_array__doc__[] = -"set_array(on) -- set whether arrays should be converted to lists"; + "set_array(on) -- set whether arrays should be converted to lists"; static PyObject * -pg_set_array(PyObject* self, PyObject* args) +pg_set_array(PyObject *self, PyObject *args) { - PyObject* ret = NULL; + PyObject *ret = NULL; int i; /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { array_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -652,7 +654,7 @@ pg_set_array(PyObject* self, PyObject* args) /* Check whether bytea values are unescaped. */ static char pg_get_bytea_escaped__doc__[] = -"get_bytea_escaped() -- check whether bytea will be returned escaped"; + "get_bytea_escaped() -- check whether bytea will be returned escaped"; static PyObject * pg_get_bytea_escaped(PyObject *self, PyObject *noargs) @@ -667,7 +669,7 @@ pg_get_bytea_escaped(PyObject *self, PyObject *noargs) /* Set usage of bool values. 
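The get/set pairs above toggle result post-processing globally; a sketch:

    import pg

    pg.set_bool(False)       # leave booleans as 't'/'f' text
    assert not pg.get_bool()
    pg.set_bool(True)        # default: convert to Python bool

    pg.set_array(False)      # leave arrays in PostgreSQL text form
    assert not pg.get_array()
    pg.set_array(True)       # default: convert to Python lists
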
*/ static char pg_set_bytea_escaped__doc__[] = -"set_bytea_escaped(on) -- set whether bytea will be returned escaped"; + "set_bytea_escaped(on) -- set whether bytea will be returned escaped"; static PyObject * pg_set_bytea_escaped(PyObject *self, PyObject *args) @@ -678,7 +680,8 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bytea_escaped = i ? 1 : 0; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -692,18 +695,15 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* set query helper functions (not part of public API) */ static char pg_set_query_helpers__doc__[] = -"set_query_helpers(*helpers) -- set internal query helper functions"; + "set_query_helpers(*helpers) -- set internal query helper functions"; static PyObject * pg_set_query_helpers(PyObject *self, PyObject *args) { /* gets arguments */ - if (!PyArg_ParseTuple(args, "O!O!O!O!", - &PyFunction_Type, &dictiter, - &PyFunction_Type, &namediter, - &PyFunction_Type, &namednext, - &PyFunction_Type, &scalariter)) - { + if (!PyArg_ParseTuple(args, "O!O!O!O!", &PyFunction_Type, &dictiter, + &PyFunction_Type, &namediter, &PyFunction_Type, + &namednext, &PyFunction_Type, &scalariter)) { return NULL; } @@ -713,7 +713,7 @@ pg_set_query_helpers(PyObject *self, PyObject *args) /* Get json decode function. */ static char pg_get_jsondecode__doc__[] = -"get_jsondecode() -- get the function used for decoding json results"; + "get_jsondecode() -- get the function used for decoding json results"; static PyObject * pg_get_jsondecode(PyObject *self, PyObject *noargs) @@ -730,7 +730,8 @@ pg_get_jsondecode(PyObject *self, PyObject *noargs) /* Set json decode function. */ static char pg_set_jsondecode__doc__[] = -"set_jsondecode(func) -- set a function to be used for decoding json results"; + "set_jsondecode(func) -- set a function to be used for decoding json " + "results"; static PyObject * pg_set_jsondecode(PyObject *self, PyObject *func) @@ -738,12 +739,17 @@ pg_set_jsondecode(PyObject *self, PyObject *func) PyObject *ret = NULL; if (func == Py_None) { - Py_XDECREF(jsondecode); jsondecode = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(jsondecode); + jsondecode = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(jsondecode); jsondecode = func; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(func); + Py_XDECREF(jsondecode); + jsondecode = func; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -756,7 +762,7 @@ pg_set_jsondecode(PyObject *self, PyObject *func) /* Get default host. */ static char pg_get_defhost__doc__[] = -"get_defhost() -- return default database host"; + "get_defhost() -- return default database host"; static PyObject * pg_get_defhost(PyObject *self, PyObject *noargs) @@ -767,7 +773,8 @@ pg_get_defhost(PyObject *self, PyObject *noargs) /* Set default host. */ static char pg_set_defhost__doc__[] = -"set_defhost(string) -- set default database host and return previous value"; + "set_defhost(string) -- set default database host and return previous " + "value"; static PyObject * pg_set_defhost(PyObject *self, PyObject *args) @@ -799,7 +806,7 @@ pg_set_defhost(PyObject *self, PyObject *args) /* Get default database. 
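set_jsondecode() accepts any callable (or None to disable decoding), so the standard library decoder plugs in directly:

    import json
    import pg

    pg.set_jsondecode(json.loads)
    assert pg.get_jsondecode() is json.loads
    pg.set_jsondecode(None)   # json values are then returned as plain strings
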
*/ static char pg_get_defbase__doc__[] = -"get_defbase() -- return default database name"; + "get_defbase() -- return default database name"; static PyObject * pg_get_defbase(PyObject *self, PyObject *noargs) @@ -810,7 +817,8 @@ pg_get_defbase(PyObject *self, PyObject *noargs) /* Set default database. */ static char pg_set_defbase__doc__[] = -"set_defbase(string) -- set default database name and return previous value"; + "set_defbase(string) -- set default database name and return previous " + "value"; static PyObject * pg_set_defbase(PyObject *self, PyObject *args) @@ -842,7 +850,7 @@ pg_set_defbase(PyObject *self, PyObject *args) /* Get default options. */ static char pg_get_defopt__doc__[] = -"get_defopt() -- return default database options"; + "get_defopt() -- return default database options"; static PyObject * pg_get_defopt(PyObject *self, PyObject *noargs) @@ -853,7 +861,7 @@ pg_get_defopt(PyObject *self, PyObject *noargs) /* Set default options. */ static char pg_set_defopt__doc__[] = -"set_defopt(string) -- set default options and return previous value"; + "set_defopt(string) -- set default options and return previous value"; static PyObject * pg_setdefopt(PyObject *self, PyObject *args) @@ -885,7 +893,7 @@ pg_setdefopt(PyObject *self, PyObject *args) /* Get default username. */ static char pg_get_defuser__doc__[] = -"get_defuser() -- return default database username"; + "get_defuser() -- return default database username"; static PyObject * pg_get_defuser(PyObject *self, PyObject *noargs) @@ -897,7 +905,7 @@ pg_get_defuser(PyObject *self, PyObject *noargs) /* Set default username. */ static char pg_set_defuser__doc__[] = -"set_defuser(name) -- set default username and return previous value"; + "set_defuser(name) -- set default username and return previous value"; static PyObject * pg_set_defuser(PyObject *self, PyObject *args) @@ -929,7 +937,7 @@ pg_set_defuser(PyObject *self, PyObject *args) /* Set default password. */ static char pg_set_defpasswd__doc__[] = -"set_defpasswd(password) -- set default database password"; + "set_defpasswd(password) -- set default database password"; static PyObject * pg_set_defpasswd(PyObject *self, PyObject *args) @@ -958,7 +966,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) /* Get default port. */ static char pg_get_defport__doc__[] = -"get_defport() -- return default database port"; + "get_defport() -- return default database port"; static PyObject * pg_get_defport(PyObject *self, PyObject *noargs) @@ -969,7 +977,7 @@ pg_get_defport(PyObject *self, PyObject *noargs) /* Set default port. */ static char pg_set_defport__doc__[] = -"set_defport(port) -- set default port and return previous value"; + "set_defport(port) -- set default port and return previous value"; static PyObject * pg_set_defport(PyObject *self, PyObject *args) @@ -1001,7 +1009,7 @@ pg_set_defport(PyObject *self, PyObject *args) /* Cast a string with a text representation of an array to a list. 
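A sketch of the default-parameter helpers defined above; the setters return the previous value, and the stored defaults feed into subsequent connect() calls (all values are placeholders):

    import pg

    previous_host = pg.set_defhost('db.example.com')
    pg.set_defport(5432)
    pg.set_defuser('app')
    pg.set_defpasswd('secret')
    pg.set_defbase('appdb')
    print(pg.get_defhost(), pg.get_defport())
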
*/ static char pg_cast_array__doc__[] = -"cast_array(string, cast=None, delim=',') -- cast a string as an array"; + "cast_array(string, cast=None, delim=',') -- cast a string as an array"; PyObject * pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) @@ -1012,10 +1020,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -1026,7 +1032,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1056,7 +1063,7 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of a record to a tuple. */ static char pg_cast_record__doc__[] = -"cast_record(string, cast=None, delim=',') -- cast a string as a record"; + "cast_record(string, cast=None, delim=',') -- cast a string as a record"; PyObject * pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) @@ -1067,10 +1074,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size, len; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -1081,7 +1086,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1096,7 +1102,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) len = 0; } else if (cast_obj == Py_None) { - cast_obj = NULL; len = 0; + cast_obj = NULL; + len = 0; } else if (PyTuple_Check(cast_obj) || PyList_Check(cast_obj)) { len = PySequence_Size(cast_obj); @@ -1120,7 +1127,7 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of an hstore to a dict. 
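The two casts above turn PostgreSQL text output into Python values; a minimal sketch matching their docstrings:

    import pg

    assert pg.cast_array('{1,2,3}', int) == [1, 2, 3]
    assert pg.cast_array('{a,b}') == ['a', 'b']          # default: plain text
    assert pg.cast_record('(1,abc)', [int, str]) == (1, 'abc')
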
*/ static char pg_cast_hstore__doc__[] = -"cast_hstore(string) -- cast a string as an hstore"; + "cast_hstore(string) -- cast a string as an hstore"; PyObject * pg_cast_hstore(PyObject *self, PyObject *string) @@ -1136,7 +1143,8 @@ pg_cast_hstore(PyObject *self, PyObject *string) } else if (PyUnicode_Check(string)) { tmp_obj = PyUnicode_AsUTF8String(string); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &s, &size); encoding = pg_encoding_utf8; } @@ -1157,50 +1165,47 @@ pg_cast_hstore(PyObject *self, PyObject *string) /* The list of functions defined in the module */ static struct PyMethodDef pg_methods[] = { - {"connect", (PyCFunction) pg_connect, - METH_VARARGS|METH_KEYWORDS, pg_connect__doc__}, - {"escape_string", (PyCFunction) pg_escape_string, - METH_O, pg_escape_string__doc__}, - {"escape_bytea", (PyCFunction) pg_escape_bytea, - METH_O, pg_escape_bytea__doc__}, - {"unescape_bytea", (PyCFunction) pg_unescape_bytea, - METH_O, pg_unescape_bytea__doc__}, - {"get_datestyle", (PyCFunction) pg_get_datestyle, - METH_NOARGS, pg_get_datestyle__doc__}, - {"set_datestyle", (PyCFunction) pg_set_datestyle, - METH_VARARGS, pg_set_datestyle__doc__}, - {"get_decimal_point", (PyCFunction) pg_get_decimal_point, - METH_NOARGS, pg_get_decimal_point__doc__}, - {"set_decimal_point", (PyCFunction) pg_set_decimal_point, - METH_VARARGS, pg_set_decimal_point__doc__}, - {"get_decimal", (PyCFunction) pg_get_decimal, - METH_NOARGS, pg_get_decimal__doc__}, - {"set_decimal", (PyCFunction) pg_set_decimal, - METH_O, pg_set_decimal__doc__}, - {"get_bool", (PyCFunction) pg_get_bool, - METH_NOARGS, pg_get_bool__doc__}, - {"set_bool", (PyCFunction) pg_set_bool, - METH_VARARGS, pg_set_bool__doc__}, - {"get_array", (PyCFunction) pg_get_array, - METH_NOARGS, pg_get_array__doc__}, - {"set_array", (PyCFunction) pg_set_array, - METH_VARARGS, pg_set_array__doc__}, - {"set_query_helpers", (PyCFunction) pg_set_query_helpers, - METH_VARARGS, pg_set_query_helpers__doc__}, - {"get_bytea_escaped", (PyCFunction) pg_get_bytea_escaped, - METH_NOARGS, pg_get_bytea_escaped__doc__}, - {"set_bytea_escaped", (PyCFunction) pg_set_bytea_escaped, - METH_VARARGS, pg_set_bytea_escaped__doc__}, - {"get_jsondecode", (PyCFunction) pg_get_jsondecode, - METH_NOARGS, pg_get_jsondecode__doc__}, - {"set_jsondecode", (PyCFunction) pg_set_jsondecode, - METH_O, pg_set_jsondecode__doc__}, - {"cast_array", (PyCFunction) pg_cast_array, - METH_VARARGS|METH_KEYWORDS, pg_cast_array__doc__}, - {"cast_record", (PyCFunction) pg_cast_record, - METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, - {"cast_hstore", (PyCFunction) pg_cast_hstore, - METH_O, pg_cast_hstore__doc__}, + {"connect", (PyCFunction)pg_connect, METH_VARARGS | METH_KEYWORDS, + pg_connect__doc__}, + {"escape_string", (PyCFunction)pg_escape_string, METH_O, + pg_escape_string__doc__}, + {"escape_bytea", (PyCFunction)pg_escape_bytea, METH_O, + pg_escape_bytea__doc__}, + {"unescape_bytea", (PyCFunction)pg_unescape_bytea, METH_O, + pg_unescape_bytea__doc__}, + {"get_datestyle", (PyCFunction)pg_get_datestyle, METH_NOARGS, + pg_get_datestyle__doc__}, + {"set_datestyle", (PyCFunction)pg_set_datestyle, METH_VARARGS, + pg_set_datestyle__doc__}, + {"get_decimal_point", (PyCFunction)pg_get_decimal_point, METH_NOARGS, + pg_get_decimal_point__doc__}, + {"set_decimal_point", (PyCFunction)pg_set_decimal_point, METH_VARARGS, + pg_set_decimal_point__doc__}, + {"get_decimal", (PyCFunction)pg_get_decimal, 
METH_NOARGS, + pg_get_decimal__doc__}, + {"set_decimal", (PyCFunction)pg_set_decimal, METH_O, + pg_set_decimal__doc__}, + {"get_bool", (PyCFunction)pg_get_bool, METH_NOARGS, pg_get_bool__doc__}, + {"set_bool", (PyCFunction)pg_set_bool, METH_VARARGS, pg_set_bool__doc__}, + {"get_array", (PyCFunction)pg_get_array, METH_NOARGS, pg_get_array__doc__}, + {"set_array", (PyCFunction)pg_set_array, METH_VARARGS, + pg_set_array__doc__}, + {"set_query_helpers", (PyCFunction)pg_set_query_helpers, METH_VARARGS, + pg_set_query_helpers__doc__}, + {"get_bytea_escaped", (PyCFunction)pg_get_bytea_escaped, METH_NOARGS, + pg_get_bytea_escaped__doc__}, + {"set_bytea_escaped", (PyCFunction)pg_set_bytea_escaped, METH_VARARGS, + pg_set_bytea_escaped__doc__}, + {"get_jsondecode", (PyCFunction)pg_get_jsondecode, METH_NOARGS, + pg_get_jsondecode__doc__}, + {"set_jsondecode", (PyCFunction)pg_set_jsondecode, METH_O, + pg_set_jsondecode__doc__}, + {"cast_array", (PyCFunction)pg_cast_array, METH_VARARGS | METH_KEYWORDS, + pg_cast_array__doc__}, + {"cast_record", (PyCFunction)pg_cast_record, METH_VARARGS | METH_KEYWORDS, + pg_cast_record__doc__}, + {"cast_hstore", (PyCFunction)pg_cast_hstore, METH_O, + pg_cast_hstore__doc__}, {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, {"get_defbase", pg_get_defbase, METH_NOARGS, pg_get_defbase__doc__}, @@ -1212,25 +1217,26 @@ static struct PyMethodDef pg_methods[] = { {"get_defuser", pg_get_defuser, METH_NOARGS, pg_get_defuser__doc__}, {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, - {"get_pqlib_version", (PyCFunction) pg_get_pqlib_version, - METH_NOARGS, pg_get_pqlib_version__doc__}, + {"get_pqlib_version", (PyCFunction)pg_get_pqlib_version, METH_NOARGS, + pg_get_pqlib_version__doc__}, {NULL, NULL} /* sentinel */ }; static char pg__doc__[] = "Python interface to PostgreSQL DB"; static struct PyModuleDef moduleDef = { - PyModuleDef_HEAD_INIT, - "_pg", /* m_name */ - pg__doc__, /* m_doc */ - -1, /* m_size */ - pg_methods /* m_methods */ + PyModuleDef_HEAD_INIT, "_pg", /* m_name */ + pg__doc__, /* m_doc */ + -1, /* m_size */ + pg_methods /* m_methods */ }; /* Initialization function for the module */ -PyMODINIT_FUNC PyInit__pg(void); +PyMODINIT_FUNC +PyInit__pg(void); -PyMODINIT_FUNC PyInit__pg(void) +PyMODINIT_FUNC +PyInit__pg(void) { PyObject *mod, *dict, *s; @@ -1239,17 +1245,13 @@ PyMODINIT_FUNC PyInit__pg(void) mod = PyModule_Create(&moduleDef); /* Initialize here because some Windows platforms get confused otherwise */ - connType.tp_base = noticeType.tp_base = - queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; + connType.tp_base = noticeType.tp_base = queryType.tp_base = + sourceType.tp_base = &PyBaseObject_Type; largeType.tp_base = &PyBaseObject_Type; - if (PyType_Ready(&connType) - || PyType_Ready(&noticeType) - || PyType_Ready(&queryType) - || PyType_Ready(&sourceType) - || PyType_Ready(&largeType) - ) - { + if (PyType_Ready(&connType) || PyType_Ready(&noticeType) || + PyType_Ready(&queryType) || PyType_Ready(&sourceType) || + PyType_Ready(&largeType)) { return NULL; } @@ -1262,48 +1264,45 @@ PyMODINIT_FUNC PyInit__pg(void) Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Warning", Warning); - InterfaceError = PyErr_NewException( - "pg.InterfaceError", Error, NULL); + InterfaceError = PyErr_NewException("pg.InterfaceError", Error, NULL); 
PyDict_SetItemString(dict, "InterfaceError", InterfaceError); - DatabaseError = PyErr_NewException( - "pg.DatabaseError", Error, NULL); + DatabaseError = PyErr_NewException("pg.DatabaseError", Error, NULL); PyDict_SetItemString(dict, "DatabaseError", DatabaseError); - InternalError = PyErr_NewException( - "pg.InternalError", DatabaseError, NULL); + InternalError = + PyErr_NewException("pg.InternalError", DatabaseError, NULL); PyDict_SetItemString(dict, "InternalError", InternalError); - OperationalError = PyErr_NewException( - "pg.OperationalError", DatabaseError, NULL); + OperationalError = + PyErr_NewException("pg.OperationalError", DatabaseError, NULL); PyDict_SetItemString(dict, "OperationalError", OperationalError); - ProgrammingError = PyErr_NewException( - "pg.ProgrammingError", DatabaseError, NULL); + ProgrammingError = + PyErr_NewException("pg.ProgrammingError", DatabaseError, NULL); PyDict_SetItemString(dict, "ProgrammingError", ProgrammingError); - IntegrityError = PyErr_NewException( - "pg.IntegrityError", DatabaseError, NULL); + IntegrityError = + PyErr_NewException("pg.IntegrityError", DatabaseError, NULL); PyDict_SetItemString(dict, "IntegrityError", IntegrityError); - DataError = PyErr_NewException( - "pg.DataError", DatabaseError, NULL); + DataError = PyErr_NewException("pg.DataError", DatabaseError, NULL); PyDict_SetItemString(dict, "DataError", DataError); - NotSupportedError = PyErr_NewException( - "pg.NotSupportedError", DatabaseError, NULL); + NotSupportedError = + PyErr_NewException("pg.NotSupportedError", DatabaseError, NULL); PyDict_SetItemString(dict, "NotSupportedError", NotSupportedError); - InvalidResultError = PyErr_NewException( - "pg.InvalidResultError", DataError, NULL); + InvalidResultError = + PyErr_NewException("pg.InvalidResultError", DataError, NULL); PyDict_SetItemString(dict, "InvalidResultError", InvalidResultError); - NoResultError = PyErr_NewException( - "pg.NoResultError", InvalidResultError, NULL); + NoResultError = + PyErr_NewException("pg.NoResultError", InvalidResultError, NULL); PyDict_SetItemString(dict, "NoResultError", NoResultError); - MultipleResultsError = PyErr_NewException( - "pg.MultipleResultsError", InvalidResultError, NULL); + MultipleResultsError = PyErr_NewException("pg.MultipleResultsError", + InvalidResultError, NULL); PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); /* Make the version available */ @@ -1320,16 +1319,24 @@ PyMODINIT_FUNC PyInit__pg(void) /* Transaction states */ PyDict_SetItemString(dict, "TRANS_IDLE", PyLong_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict, "TRANS_ACTIVE", PyLong_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict, "TRANS_INTRANS", PyLong_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict, "TRANS_INERROR", PyLong_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyLong_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", + PyLong_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", + PyLong_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", + PyLong_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", + PyLong_FromLong(PQTRANS_UNKNOWN)); /* Polling results */ - PyDict_SetItemString(dict, "POLLING_OK", PyLong_FromLong(PGRES_POLLING_OK)); - PyDict_SetItemString(dict, "POLLING_FAILED", PyLong_FromLong(PGRES_POLLING_FAILED)); - PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); - PyDict_SetItemString(dict, 
"POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); + PyDict_SetItemString(dict, "POLLING_OK", + PyLong_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", + PyLong_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", + PyLong_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", + PyLong_FromLong(PGRES_POLLING_WRITING)); /* Create mode for large objects */ PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); diff --git a/pgnotice.c b/pgnotice.c index e079283c..0252a56f 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -25,7 +25,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (self->pgcnx && _check_cnx_obj(self->pgcnx)) { Py_INCREF(self->pgcnx); - return (PyObject *) self->pgcnx; + return (PyObject *)self->pgcnx; } else { Py_INCREF(Py_None); @@ -54,11 +54,12 @@ notice_getattr(noticeObject *self, PyObject *nameobj) return PyUnicode_FromString(s); } else { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } } - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Get the list of notice attributes. */ @@ -67,10 +68,9 @@ notice_dir(noticeObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[ssssss]", - "pgcnx", "severity", "message", "primary", "detail", "hint"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[ssssss]", "pgcnx", "severity", + "message", "primary", "detail", "hint"); return attrs; } @@ -84,41 +84,38 @@ notice_str(noticeObject *self) /* Notice object methods */ static struct PyMethodDef notice_methods[] = { - {"__dir__", (PyCFunction) notice_dir, METH_NOARGS, NULL}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)notice_dir, METH_NOARGS, NULL}, {NULL, NULL}}; static char notice__doc__[] = "PostgreSQL notice object"; /* Notice type definition */ static PyTypeObject noticeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Notice", /* tp_name */ - sizeof(noticeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Notice", /* tp_name */ + sizeof(noticeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) notice_str, /* tp_str */ - (getattrofunc) notice_getattr, /* tp_getattro */ - PyObject_GenericSetAttr, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - notice__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - notice_methods, /* tp_methods */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)notice_str, /* tp_str */ + (getattrofunc)notice_getattr, /* tp_getattro */ + PyObject_GenericSetAttr, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + notice__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* 
tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + notice_methods, /* tp_methods */ }; diff --git a/pgquery.c b/pgquery.c index 194bfaa1..6346497d 100644 --- a/pgquery.c +++ b/pgquery.c @@ -37,7 +37,7 @@ query_len(PyObject *self) PyObject *tmp; Py_ssize_t len; - tmp = PyLong_FromLong(((queryObject*) self)->max_row); + tmp = PyLong_FromLong(((queryObject *)self)->max_row); len = PyLong_AsSsize_t(tmp); Py_DECREF(tmp); return len; @@ -64,18 +64,18 @@ _query_value_in_column(queryObject *self, int column) /* cast the string representation into a Python object */ if (type & PYGRES_ARRAY) return cast_array(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, type, NULL, 0); + PQgetlength(self->result, self->current_row, column), + self->encoding, type, NULL, 0); if (type == PYGRES_BYTEA) return cast_bytea_text(s); if (type == PYGRES_OTHER) return cast_other(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, - PQftype(self->result, column), self->pgcnx->cast_hook); + PQgetlength(self->result, self->current_row, column), + self->encoding, PQftype(self->result, column), + self->pgcnx->cast_hook); if (type & PYGRES_TEXT) - return cast_sized_text(s, - PQgetlength(self->result, self->current_row, column), + return cast_sized_text( + s, PQgetlength(self->result, self->current_row, column), self->encoding, type); return cast_unsized_simple(s, type); } @@ -94,7 +94,8 @@ _query_row_as_tuple(queryObject *self) for (j = 0; j < self->num_fields; ++j) { PyObject *val = _query_value_in_column(self, j); if (!val) { - Py_DECREF(row_tuple); return NULL; + Py_DECREF(row_tuple); + return NULL; } PyTuple_SET_ITEM(row_tuple, j, val); } @@ -108,7 +109,8 @@ _query_row_as_tuple(queryObject *self) If this is a normal query result, the query itself will be returned, otherwise a result value will be returned that shall be passed on. */ static PyObject * -_get_async_result(queryObject *self, int keep) { +_get_async_result(queryObject *self, int keep) +{ int fetch = 0; if (self->async) { @@ -118,7 +120,8 @@ _get_async_result(queryObject *self, int keep) { /* mark query as fetched, do not fetch again */ self->async = 2; } - } else if (!keep) { + } + else if (!keep) { self->async = 1; } } @@ -147,8 +150,8 @@ _get_async_result(queryObject *self, int keep) { } if ((status = PQresultStatus(self->result)) != PGRES_TUPLES_OK) { - PyObject* result = _conn_non_query_result( - status, self->result, self->pgcnx->cnx); + PyObject *result = + _conn_non_query_result(status, self->result, self->pgcnx->cnx); self->result = NULL; /* since this has been already cleared */ if (!result) { /* Raise an error. 
We need to call PQgetResult() to clear the @@ -181,8 +184,9 @@ _get_async_result(queryObject *self, int keep) { Py_DECREF(self); return NULL; } - } else if (self->async == 2 && - !self->max_row && !self->num_fields && !self->col_types) { + } + else if (self->async == 2 && !self->max_row && !self->num_fields && + !self->col_types) { Py_INCREF(Py_None); return Py_None; } @@ -195,14 +199,14 @@ _get_async_result(queryObject *self, int keep) { static PyObject * query_getitem(PyObject *self, Py_ssize_t i) { - queryObject *q = (queryObject *) self; + queryObject *q = (queryObject *)self; PyObject *tmp; long row; if ((tmp = _get_async_result(q, 0)) != (PyObject *)self) return tmp; - tmp = PyLong_FromSize_t((size_t) i); + tmp = PyLong_FromSize_t((size_t)i); row = PyLong_AsLong(tmp); Py_DECREF(tmp); @@ -211,13 +215,14 @@ query_getitem(PyObject *self, Py_ssize_t i) return NULL; } - q->current_row = (int) row; + q->current_row = (int)row; return _query_row_as_tuple(q); } /* __iter__() method of the queryObject: Returns the default iterator yielding rows as tuples. */ -static PyObject* query_iter(queryObject *self) +static PyObject * +query_iter(queryObject *self) { PyObject *res; @@ -226,7 +231,7 @@ static PyObject* query_iter(queryObject *self) self->current_row = 0; Py_INCREF(self); - return (PyObject*) self; + return (PyObject *)self; } /* __next__() method of the queryObject: @@ -242,13 +247,14 @@ query_next(queryObject *self, PyObject *noargs) } row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; return row_tuple; } /* Get number of bytes allocated for PGresult object */ static char query_memsize__doc__[] = -"memsize() -- return number of bytes allocated by query result"; + "memsize() -- return number of bytes allocated by query result"; static PyObject * query_memsize(queryObject *self, PyObject *noargs) { @@ -262,7 +268,7 @@ query_memsize(queryObject *self, PyObject *noargs) /* List field names from query result. */ static char query_listfields__doc__[] = -"listfields() -- List field names from result"; + "listfields() -- List field names from result"; static PyObject * query_listfields(queryObject *self, PyObject *noargs) @@ -285,7 +291,7 @@ query_listfields(queryObject *self, PyObject *noargs) /* Get field name from number in last result. */ static char query_fieldname__doc__[] = -"fieldname(num) -- return name of field from result from its position"; + "fieldname(num) -- return name of field from result from its position"; static PyObject * query_fieldname(queryObject *self, PyObject *args) @@ -313,7 +319,7 @@ query_fieldname(queryObject *self, PyObject *args) /* Get field number from name in last result. */ static char query_fieldnum__doc__[] = -"fieldnum(name) -- return position in query for field from its name"; + "fieldnum(name) -- return position in query for field from its name"; static PyObject * query_fieldnum(queryObject *self, PyObject *args) @@ -339,13 +345,15 @@ query_fieldnum(queryObject *self, PyObject *args) /* Build a tuple with info for query field with given number. 
*/ static PyObject * -_query_build_field_info(PGresult *res, int col_num) { +_query_build_field_info(PGresult *res, int col_num) +{ PyObject *info; info = PyTuple_New(4); if (info) { PyTuple_SET_ITEM(info, 0, PyUnicode_FromString(PQfname(res, col_num))); - PyTuple_SET_ITEM(info, 1, PyLong_FromLong((long) PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 1, + PyLong_FromLong((long)PQftype(res, col_num))); PyTuple_SET_ITEM(info, 2, PyLong_FromLong(PQfsize(res, col_num))); PyTuple_SET_ITEM(info, 3, PyLong_FromLong(PQfmod(res, col_num))); } @@ -354,7 +362,7 @@ _query_build_field_info(PGresult *res, int col_num) { /* Get information on one or all fields of the query result. */ static char query_fieldinfo__doc__[] = -"fieldinfo([name]) -- return information about field(s) in query result"; + "fieldinfo([name]) -- return information about field(s) in query result"; static PyObject * query_fieldinfo(queryObject *self, PyObject *args) @@ -374,14 +382,18 @@ query_fieldinfo(queryObject *self, PyObject *args) /* gets field number */ if (PyBytes_Check(field)) { num = PQfnumber(self->result, PyBytes_AsString(field)); - } else if (PyUnicode_Check(field)) { + } + else if (PyUnicode_Check(field)) { PyObject *tmp = get_encoded_string(field, self->encoding); - if (!tmp) return NULL; + if (!tmp) + return NULL; num = PQfnumber(self->result, PyBytes_AsString(tmp)); Py_DECREF(tmp); - } else if (PyLong_Check(field)) { - num = (int) PyLong_AsLong(field); - } else { + } + else if (PyLong_Check(field)) { + num = (int)PyLong_AsLong(field); + } + else { PyErr_SetString(PyExc_TypeError, "Field should be given as column number or name"); return NULL; @@ -407,13 +419,12 @@ query_fieldinfo(queryObject *self, PyObject *args) return result; } - /* Retrieve one row from the result as a tuple. */ static char query_one__doc__[] = -"one() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "one() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_one(queryObject *self, PyObject *noargs) @@ -421,13 +432,14 @@ query_one(queryObject *self, PyObject *noargs) PyObject *row_tuple; if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; } return row_tuple; @@ -435,11 +447,13 @@ query_one(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a tuple. 
*/ static char query_single__doc__[] = -"single() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "single() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a tuple of fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_single(queryObject *self, PyObject *noargs) @@ -447,7 +461,6 @@ query_single(queryObject *self, PyObject *noargs) PyObject *row_tuple; if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->max_row != 1) { if (self->max_row) set_error_msg(MultipleResultsError, "Multiple results found"); @@ -458,7 +471,8 @@ query_single(queryObject *self, PyObject *noargs) self->current_row = 0; row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; } return row_tuple; @@ -466,9 +480,9 @@ query_single(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of tuples. */ static char query_getresult__doc__[] = -"getresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a tuple of fields\n" -"in the order returned by the server.\n"; + "getresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a tuple of fields\n" + "in the order returned by the server.\n"; static PyObject * query_getresult(queryObject *self, PyObject *noargs) @@ -477,7 +491,6 @@ query_getresult(queryObject *self, PyObject *noargs) int i; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!(result_list = PyList_New(self->max_row))) { return NULL; } @@ -486,7 +499,8 @@ query_getresult(queryObject *self, PyObject *noargs) PyObject *row_tuple = query_next(self, noargs); if (!row_tuple) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, i, row_tuple); } @@ -510,7 +524,8 @@ _query_row_as_dict(queryObject *self) PyObject *val = _query_value_in_column(self, j); if (!val) { - Py_DECREF(row_dict); return NULL; + Py_DECREF(row_dict); + return NULL; } PyDict_SetItemString(row_dict, PQfname(self->result, j), val); Py_DECREF(val); @@ -531,17 +546,18 @@ query_next_dict(queryObject *self, PyObject *noargs) } row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; return row_dict; } /* Retrieve one row from the result as a dictionary. 
*/ static char query_onedict__doc__[] = -"onedict() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "onedict() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onedict(queryObject *self, PyObject *noargs) @@ -549,13 +565,14 @@ query_onedict(queryObject *self, PyObject *noargs) PyObject *row_dict; if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; } return row_dict; @@ -563,12 +580,14 @@ query_onedict(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a dictionary. */ static char query_singledict__doc__[] = -"singledict() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singledict() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singledict(queryObject *self, PyObject *noargs) @@ -576,7 +595,6 @@ query_singledict(queryObject *self, PyObject *noargs) PyObject *row_dict; if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->max_row != 1) { if (self->max_row) set_error_msg(MultipleResultsError, "Multiple results found"); @@ -587,7 +605,8 @@ query_singledict(queryObject *self, PyObject *noargs) self->current_row = 0; row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; } return row_dict; @@ -595,9 +614,9 @@ query_singledict(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of dictionaries. 
*/ static char query_dictresult__doc__[] = -"dictresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a dictionary with\n" -"the field names used as the keys.\n"; + "dictresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a dictionary with\n" + "the field names used as the keys.\n"; static PyObject * query_dictresult(queryObject *self, PyObject *noargs) @@ -606,7 +625,6 @@ query_dictresult(queryObject *self, PyObject *noargs) int i; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!(result_list = PyList_New(self->max_row))) { return NULL; } @@ -615,7 +633,8 @@ query_dictresult(queryObject *self, PyObject *noargs) PyObject *row_dict = query_next_dict(self, noargs); if (!row_dict) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, i, row_dict); } @@ -626,9 +645,9 @@ query_dictresult(queryObject *self, PyObject *noargs) /* Retrieve last result as iterator of dictionaries. */ static char query_dictiter__doc__[] = -"dictiter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a a dictionary\n" -"with the field names used as the keys.\n"; + "dictiter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a a dictionary\n" + "with the field names used as the keys.\n"; static PyObject * query_dictiter(queryObject *self, PyObject *noargs) @@ -647,10 +666,10 @@ query_dictiter(queryObject *self, PyObject *noargs) /* Retrieve one row from the result as a named tuple. */ static char query_onenamed__doc__[] = -"onenamed() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a named tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "onenamed() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a named tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onenamed(queryObject *self, PyObject *noargs) @@ -665,7 +684,8 @@ query_onenamed(queryObject *self, PyObject *noargs) return res; if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } return PyObject_CallFunction(namednext, "(O)", self); @@ -673,11 +693,14 @@ query_onenamed(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a tuple. 
*/ static char query_singlenamed__doc__[] = -"singlenamed() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as named tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singlenamed() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as named tuple of " + "fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singlenamed(queryObject *self, PyObject *noargs) @@ -705,9 +728,10 @@ query_singlenamed(queryObject *self, PyObject *noargs) /* Retrieve last result as list of named tuples. */ static char query_namedresult__doc__[] = -"namedresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a named tuple of fields\n" -"in the order returned by the server.\n"; + "namedresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a named tuple of " + "fields\n" + "in the order returned by the server.\n"; static PyObject * query_namedresult(queryObject *self, PyObject *noargs) @@ -720,8 +744,10 @@ query_namedresult(queryObject *self, PyObject *noargs) if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (PyList_Check(res)) return res; + if (!res) + return NULL; + if (PyList_Check(res)) + return res; res_list = PySequence_List(res); Py_DECREF(res); } @@ -731,9 +757,9 @@ query_namedresult(queryObject *self, PyObject *noargs) /* Retrieve last result as iterator of named tuples. */ static char query_namediter__doc__[] = -"namediter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a named tuple\n" -"of fields in the order returned by the server.\n"; + "namediter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a named tuple\n" + "of fields in the order returned by the server.\n"; static PyObject * query_namediter(queryObject *self, PyObject *noargs) @@ -745,11 +771,12 @@ query_namediter(queryObject *self, PyObject *noargs) } if ((res_iter = _get_async_result(self, 1)) == (PyObject *)self) { - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (!PyList_Check(res)) return res; - res_iter = (Py_TYPE(res)->tp_iter)((PyObject *) self); + if (!res) + return NULL; + if (!PyList_Check(res)) + return res; + res_iter = (Py_TYPE(res)->tp_iter)((PyObject *)self); Py_DECREF(res); } @@ -758,9 +785,9 @@ query_namediter(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of scalar values. 
*/ static char query_scalarresult__doc__[] = -"scalarresult() -- Get query result as scalars\n\n" -"The result is returned as a list of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; + "scalarresult() -- Get query result as scalars\n\n" + "The result is returned as a list of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; static PyObject * query_scalarresult(queryObject *self, PyObject *noargs) @@ -768,7 +795,6 @@ query_scalarresult(queryObject *self, PyObject *noargs) PyObject *result_list; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; @@ -778,14 +804,13 @@ query_scalarresult(queryObject *self, PyObject *noargs) return NULL; } - for (self->current_row = 0; - self->current_row < self->max_row; - ++self->current_row) - { + for (self->current_row = 0; self->current_row < self->max_row; + ++self->current_row) { PyObject *value = _query_value_in_column(self, 0); if (!value) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, self->current_row, value); } @@ -796,9 +821,9 @@ query_scalarresult(queryObject *self, PyObject *noargs) /* Retrieve the last query result as iterator of scalar values. */ static char query_scalariter__doc__[] = -"scalariter() -- Get query result as scalars\n\n" -"The result is returned as an iterator of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; + "scalariter() -- Get query result as scalars\n\n" + "The result is returned as an iterator of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; static PyObject * query_scalariter(queryObject *self, PyObject *noargs) @@ -822,10 +847,12 @@ query_scalariter(queryObject *self, PyObject *noargs) /* Retrieve one result as scalar value. */ static char query_onescalar__doc__[] = -"onescalar() -- Get one scalar value from the result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method can be called multiple times to return more rows as scalars.\n" -"It returns None if the result does not contain one more row.\n"; + "onescalar() -- Get one scalar value from the result of a query\n\n" + "Returns the first field of the next row from the result as a scalar " + "value.\n" + "This method can be called multiple times to return more rows as " + "scalars.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onescalar(queryObject *self, PyObject *noargs) @@ -833,18 +860,19 @@ query_onescalar(queryObject *self, PyObject *noargs) PyObject *value; if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; } if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } value = _query_value_in_column(self, 0); - if (value) ++self->current_row; + if (value) + ++self->current_row; } return value; @@ -852,11 +880,14 @@ query_onescalar(queryObject *self, PyObject *noargs) /* Retrieves the single row from the result as a tuple. 
*/ static char query_singlescalar__doc__[] = -"singlescalar() -- Get scalar value from single result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singlescalar() -- Get scalar value from single result of a query\n\n" + "Returns the first field of the next row from the result as a scalar " + "value.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singlescalar(queryObject *self, PyObject *noargs) @@ -864,7 +895,6 @@ query_singlescalar(queryObject *self, PyObject *noargs) PyObject *value; if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; @@ -880,7 +910,8 @@ query_singlescalar(queryObject *self, PyObject *noargs) self->current_row = 0; value = _query_value_in_column(self, 0); - if (value) ++self->current_row; + if (value) + ++self->current_row; } return value; @@ -888,92 +919,86 @@ query_singlescalar(queryObject *self, PyObject *noargs) /* Query sequence protocol methods */ static PySequenceMethods query_sequence_methods = { - (lenfunc) query_len, /* sq_length */ - 0, /* sq_concat */ - 0, /* sq_repeat */ - (ssizeargfunc) query_getitem, /* sq_item */ - 0, /* sq_ass_item */ - 0, /* sq_contains */ - 0, /* sq_inplace_concat */ - 0, /* sq_inplace_repeat */ + (lenfunc)query_len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)query_getitem, /* sq_item */ + 0, /* sq_ass_item */ + 0, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ }; /* Query object methods */ static struct PyMethodDef query_methods[] = { - {"getresult", (PyCFunction) query_getresult, - METH_NOARGS, query_getresult__doc__}, - {"dictresult", (PyCFunction) query_dictresult, - METH_NOARGS, query_dictresult__doc__}, - {"dictiter", (PyCFunction) query_dictiter, - METH_NOARGS, query_dictiter__doc__}, - {"namedresult", (PyCFunction) query_namedresult, - METH_NOARGS, query_namedresult__doc__}, - {"namediter", (PyCFunction) query_namediter, - METH_NOARGS, query_namediter__doc__}, - {"one", (PyCFunction) query_one, - METH_NOARGS, query_one__doc__}, - {"single", (PyCFunction) query_single, - METH_NOARGS, query_single__doc__}, - {"onedict", (PyCFunction) query_onedict, - METH_NOARGS, query_onedict__doc__}, - {"singledict", (PyCFunction) query_singledict, - METH_NOARGS, query_singledict__doc__}, - {"onenamed", (PyCFunction) query_onenamed, - METH_NOARGS, query_onenamed__doc__}, - {"singlenamed", (PyCFunction) query_singlenamed, - METH_NOARGS, query_singlenamed__doc__}, - {"scalarresult", (PyCFunction) query_scalarresult, - METH_NOARGS, query_scalarresult__doc__}, - {"scalariter", (PyCFunction) query_scalariter, - METH_NOARGS, query_scalariter__doc__}, - {"onescalar", (PyCFunction) query_onescalar, - METH_NOARGS, query_onescalar__doc__}, - {"singlescalar", (PyCFunction) query_singlescalar, - METH_NOARGS, query_singlescalar__doc__}, - {"fieldname", (PyCFunction) query_fieldname, - METH_VARARGS, query_fieldname__doc__}, - {"fieldnum", (PyCFunction) query_fieldnum, - METH_VARARGS, 
query_fieldnum__doc__}, - {"listfields", (PyCFunction) query_listfields, - METH_NOARGS, query_listfields__doc__}, - {"fieldinfo", (PyCFunction) query_fieldinfo, - METH_VARARGS, query_fieldinfo__doc__}, - {"memsize", (PyCFunction) query_memsize, - METH_NOARGS, query_memsize__doc__}, - {NULL, NULL} -}; + {"getresult", (PyCFunction)query_getresult, METH_NOARGS, + query_getresult__doc__}, + {"dictresult", (PyCFunction)query_dictresult, METH_NOARGS, + query_dictresult__doc__}, + {"dictiter", (PyCFunction)query_dictiter, METH_NOARGS, + query_dictiter__doc__}, + {"namedresult", (PyCFunction)query_namedresult, METH_NOARGS, + query_namedresult__doc__}, + {"namediter", (PyCFunction)query_namediter, METH_NOARGS, + query_namediter__doc__}, + {"one", (PyCFunction)query_one, METH_NOARGS, query_one__doc__}, + {"single", (PyCFunction)query_single, METH_NOARGS, query_single__doc__}, + {"onedict", (PyCFunction)query_onedict, METH_NOARGS, query_onedict__doc__}, + {"singledict", (PyCFunction)query_singledict, METH_NOARGS, + query_singledict__doc__}, + {"onenamed", (PyCFunction)query_onenamed, METH_NOARGS, + query_onenamed__doc__}, + {"singlenamed", (PyCFunction)query_singlenamed, METH_NOARGS, + query_singlenamed__doc__}, + {"scalarresult", (PyCFunction)query_scalarresult, METH_NOARGS, + query_scalarresult__doc__}, + {"scalariter", (PyCFunction)query_scalariter, METH_NOARGS, + query_scalariter__doc__}, + {"onescalar", (PyCFunction)query_onescalar, METH_NOARGS, + query_onescalar__doc__}, + {"singlescalar", (PyCFunction)query_singlescalar, METH_NOARGS, + query_singlescalar__doc__}, + {"fieldname", (PyCFunction)query_fieldname, METH_VARARGS, + query_fieldname__doc__}, + {"fieldnum", (PyCFunction)query_fieldnum, METH_VARARGS, + query_fieldnum__doc__}, + {"listfields", (PyCFunction)query_listfields, METH_NOARGS, + query_listfields__doc__}, + {"fieldinfo", (PyCFunction)query_fieldinfo, METH_VARARGS, + query_fieldinfo__doc__}, + {"memsize", (PyCFunction)query_memsize, METH_NOARGS, query_memsize__doc__}, + {NULL, NULL}}; static char query__doc__[] = "PyGreSQL query object"; /* Query type definition */ static PyTypeObject queryType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Query", /* tp_name */ - sizeof(queryObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Query", /* tp_name */ + sizeof(queryObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) query_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - &query_sequence_methods, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) query_str, /* tp_str */ - PyObject_GenericGetAttr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - query__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - (getiterfunc) query_iter, /* tp_iter */ - (iternextfunc) query_next, /* tp_iternext */ - query_methods, /* tp_methods */ + (destructor)query_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + &query_sequence_methods, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)query_str, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* 
tp_flags */ + query__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)query_iter, /* tp_iter */ + (iternextfunc)query_next, /* tp_iternext */ + query_methods, /* tp_methods */ }; diff --git a/pgsource.c b/pgsource.c index 7b081273..73c9a52b 100644 --- a/pgsource.c +++ b/pgsource.c @@ -71,7 +71,7 @@ source_getattr(sourceObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (_check_source_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } Py_INCREF(Py_None); return Py_None; @@ -94,7 +94,7 @@ source_getattr(sourceObject *self, PyObject *nameobj) return PyLong_FromLong(self->num_fields); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Set source object attributes. */ @@ -119,8 +119,9 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object. */ static char source_close__doc__[] = -"close() -- close query object without deleting it\n\n" -"All instances of the query object can no longer be used after this call.\n"; + "close() -- close query object without deleting it\n\n" + "All instances of the query object can no longer be used after this " + "call.\n"; static PyObject * source_close(sourceObject *self, PyObject *noargs) @@ -141,15 +142,15 @@ source_close(sourceObject *self, PyObject *noargs) /* Database query. */ static char source_execute__doc__[] = -"execute(sql) -- execute a SQL statement (string)\n\n" -"On success, this call returns the number of affected rows, or None\n" -"for DQL (SELECT, ...) statements. The fetch (fetch(), fetchone()\n" -"and fetchall()) methods can be used to get result rows.\n"; + "execute(sql) -- execute a SQL statement (string)\n\n" + "On success, this call returns the number of affected rows, or None\n" + "for DQL (SELECT, ...) statements. 
The fetch (fetch(), fetchone()\n" + "and fetchall()) methods can be used to get result rows.\n"; static PyObject * source_execute(sourceObject *self, PyObject *sql) { - PyObject *tmp_obj = NULL; /* auxiliary string object */ + PyObject *tmp_obj = NULL; /* auxiliary string object */ char *query; int encoding; @@ -165,7 +166,8 @@ source_execute(sourceObject *self, PyObject *sql) } else if (PyUnicode_Check(sql)) { tmp_obj = get_encoded_string(sql, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ query = PyBytes_AsString(tmp_obj); } else { @@ -205,30 +207,29 @@ source_execute(sourceObject *self, PyObject *sql) /* checks result status */ switch (PQresultStatus(self->result)) { /* query succeeded */ - case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ + case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ self->result_type = RESULT_DQL; self->max_row = PQntuples(self->result); self->num_fields = PQnfields(self->result); Py_INCREF(Py_None); return Py_None; - case PGRES_COMMAND_OK: /* other requests */ + case PGRES_COMMAND_OK: /* other requests */ case PGRES_COPY_OUT: - case PGRES_COPY_IN: - { - long num_rows; - char *tmp; - - tmp = PQcmdTuples(self->result); - if (tmp[0]) { - self->result_type = RESULT_DML; - num_rows = atol(tmp); - } - else { - self->result_type = RESULT_DDL; - num_rows = -1; - } - return PyLong_FromLong(num_rows); + case PGRES_COPY_IN: { + long num_rows; + char *tmp; + + tmp = PQcmdTuples(self->result); + if (tmp[0]) { + self->result_type = RESULT_DML; + num_rows = atol(tmp); } + else { + self->result_type = RESULT_DDL; + num_rows = -1; + } + return PyLong_FromLong(num_rows); + } /* query failed */ case PGRES_EMPTY_QUERY: @@ -238,7 +239,7 @@ source_execute(sourceObject *self, PyObject *sql) case PGRES_FATAL_ERROR: case PGRES_NONFATAL_ERROR: set_error(ProgrammingError, "Cannot execute command", - self->pgcnx->cnx, self->result); + self->pgcnx->cnx, self->result); break; default: set_error_msg(InternalError, @@ -254,7 +255,7 @@ source_execute(sourceObject *self, PyObject *sql) /* Get oid status for last query (valid for INSERTs, 0 for other). */ static char source_oidstatus__doc__[] = -"oidstatus() -- return oid of last inserted row (if available)"; + "oidstatus() -- return oid of last inserted row (if available)"; static PyObject * source_oidstatus(sourceObject *self, PyObject *noargs) @@ -272,14 +273,14 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyLong_FromLong((long) oid); + return PyLong_FromLong((long)oid); } /* Fetch rows from last result. 
*/ static char source_fetch__doc__[] = -"fetch(num) -- return the next num rows from the last result in a list\n\n" -"If num parameter is omitted arraysize attribute value is used.\n" -"If size equals -1, all rows are fetched.\n"; + "fetch(num) -- return the next num rows from the last result in a list\n\n" + "If num parameter is omitted arraysize attribute value is used.\n" + "If size equals -1, all rows are fetched.\n"; static PyObject * source_fetch(sourceObject *self, PyObject *args) @@ -309,7 +310,8 @@ source_fetch(sourceObject *self, PyObject *args) } /* allocate list for result */ - if (!(res_list = PyList_New(0))) return NULL; + if (!(res_list = PyList_New(0))) + return NULL; encoding = self->encoding; @@ -319,7 +321,8 @@ source_fetch(sourceObject *self, PyObject *args) int j; if (!(rowtuple = PyTuple_New(self->num_fields))) { - Py_DECREF(res_list); return NULL; + Py_DECREF(res_list); + return NULL; } for (j = 0; j < self->num_fields; ++j) { @@ -345,7 +348,9 @@ source_fetch(sourceObject *self, PyObject *args) } if (PyList_Append(res_list, rowtuple)) { - Py_DECREF(rowtuple); Py_DECREF(res_list); return NULL; + Py_DECREF(rowtuple); + Py_DECREF(res_list); + return NULL; } Py_DECREF(rowtuple); } @@ -387,7 +392,7 @@ _source_move(sourceObject *self, int move) /* Move to first result row. */ static char source_movefirst__doc__[] = -"movefirst() -- move to first result row"; + "movefirst() -- move to first result row"; static PyObject * source_movefirst(sourceObject *self, PyObject *noargs) @@ -397,7 +402,7 @@ source_movefirst(sourceObject *self, PyObject *noargs) /* Move to last result row. */ static char source_movelast__doc__[] = -"movelast() -- move to last valid result row"; + "movelast() -- move to last valid result row"; static PyObject * source_movelast(sourceObject *self, PyObject *noargs) @@ -406,8 +411,7 @@ source_movelast(sourceObject *self, PyObject *noargs) } /* Move to next result row. */ -static char source_movenext__doc__[] = -"movenext() -- move to next result row"; +static char source_movenext__doc__[] = "movenext() -- move to next result row"; static PyObject * source_movenext(sourceObject *self, PyObject *noargs) @@ -417,7 +421,7 @@ source_movenext(sourceObject *self, PyObject *noargs) /* Move to previous result row. */ static char source_moveprev__doc__[] = -"moveprev() -- move to previous result row"; + "moveprev() -- move to previous result row"; static PyObject * source_moveprev(sourceObject *self, PyObject *noargs) @@ -427,17 +431,17 @@ source_moveprev(sourceObject *self, PyObject *noargs) /* Put copy data. 
*/ static char source_putdata__doc__[] = -"putdata(buffer) -- send data to server during copy from stdin"; + "putdata(buffer) -- send data to server during copy from stdin"; static PyObject * source_putdata(sourceObject *self, PyObject *buffer) { - PyObject *tmp_obj = NULL; /* an auxiliary object */ - char *buf; /* the buffer as encoded string */ - Py_ssize_t nbytes; /* length of string */ - char *errormsg = NULL; /* error message */ - int res; /* direct result of the operation */ - PyObject *ret; /* return value */ + PyObject *tmp_obj = NULL; /* an auxiliary object */ + char *buf; /* the buffer as encoded string */ + Py_ssize_t nbytes; /* length of string */ + char *errormsg = NULL; /* error message */ + int res; /* direct result of the operation */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -459,9 +463,10 @@ source_putdata(sourceObject *self, PyObject *buffer) } else if (PyUnicode_Check(buffer)) { /* or pass a unicode string */ - tmp_obj = get_encoded_string( - buffer, PQclientEncoding(self->pgcnx->cnx)); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + tmp_obj = + get_encoded_string(buffer, PQclientEncoding(self->pgcnx->cnx)); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &buf, &nbytes); } else if (PyErr_GivenExceptionMatches(buffer, PyExc_BaseException)) { @@ -470,10 +475,11 @@ source_putdata(sourceObject *self, PyObject *buffer) if (PyUnicode_Check(tmp_obj)) { PyObject *obj = tmp_obj; - tmp_obj = get_encoded_string( - obj, PQclientEncoding(self->pgcnx->cnx)); + tmp_obj = + get_encoded_string(obj, PQclientEncoding(self->pgcnx->cnx)); Py_DECREF(obj); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ } errormsg = PyBytes_AsString(tmp_obj); buf = NULL; @@ -487,8 +493,7 @@ source_putdata(sourceObject *self, PyObject *buffer) /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_IN) - { + PQresultStatus(self->result) != PGRES_COPY_IN) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_in state"); Py_XDECREF(tmp_obj); @@ -496,7 +501,7 @@ source_putdata(sourceObject *self, PyObject *buffer) } if (buf) { - res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int) nbytes) : 1; + res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int)nbytes) : 1; } else { res = PQputCopyEnd(self->pgcnx->cnx, errormsg); @@ -513,7 +518,7 @@ source_putdata(sourceObject *self, PyObject *buffer) ret = Py_None; Py_INCREF(ret); } - else { /* copy is done */ + else { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -529,7 +534,8 @@ source_putdata(sourceObject *self, PyObject *buffer) ret = PyLong_FromLong(num_rows); } else { - if (!errormsg) errormsg = PQerrorMessage(self->pgcnx->cnx); + if (!errormsg) + errormsg = PQerrorMessage(self->pgcnx->cnx); PyErr_SetString(PyExc_IOError, errormsg); ret = NULL; } @@ -544,15 +550,15 @@ source_putdata(sourceObject *self, PyObject *buffer) /* Get copy data. 
*/ static char source_getdata__doc__[] = -"getdata(decode) -- receive data to server during copy to stdout"; + "getdata(decode) -- receive data to server during copy to stdout"; static PyObject * source_getdata(sourceObject *self, PyObject *args) { - int *decode = 0; /* decode flag */ - char *buffer; /* the copied buffer as encoded byte string */ - Py_ssize_t nbytes; /* length of the byte string */ - PyObject *ret; /* return value */ + int *decode = 0; /* decode flag */ + char *buffer; /* the copied buffer as encoded byte string */ + Py_ssize_t nbytes; /* length of the byte string */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -570,8 +576,7 @@ source_getdata(sourceObject *self, PyObject *args) /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_OUT) - { + PQresultStatus(self->result) != PGRES_COPY_OUT) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_out state"); return NULL; @@ -584,7 +589,7 @@ source_getdata(sourceObject *self, PyObject *args) return NULL; } - if (nbytes == -1) { /* copy is done */ + if (nbytes == -1) { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -609,9 +614,9 @@ source_getdata(sourceObject *self, PyObject *args) self->result_type = RESULT_EMPTY; } else { /* a row has been returned */ - ret = decode ? get_decoded_string( - buffer, nbytes, PQclientEncoding(self->pgcnx->cnx)) : - PyBytes_FromStringAndSize(buffer, nbytes); + ret = decode ? get_decoded_string(buffer, nbytes, + PQclientEncoding(self->pgcnx->cnx)) + : PyBytes_FromStringAndSize(buffer, nbytes); PQfreemem(buffer); } @@ -633,7 +638,7 @@ _source_fieldindex(sourceObject *self, PyObject *param, const char *usage) num = PQfnumber(self->result, PyBytes_AsString(param)); } else if (PyLong_Check(param)) { - num = (int) PyLong_AsLong(param); + num = (int)PyLong_AsLong(param); } else { PyErr_SetString(PyExc_TypeError, usage); @@ -664,20 +669,18 @@ _source_buildinfo(sourceObject *self, int num) /* affects field information */ PyTuple_SET_ITEM(result, 0, PyLong_FromLong(num)); PyTuple_SET_ITEM(result, 1, - PyUnicode_FromString(PQfname(self->result, num))); + PyUnicode_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyLong_FromLong((long) PQftype(self->result, num))); - PyTuple_SET_ITEM(result, 3, - PyLong_FromLong(PQfsize(self->result, num))); - PyTuple_SET_ITEM(result, 4, - PyLong_FromLong(PQfmod(self->result, num))); + PyLong_FromLong((long)PQftype(self->result, num))); + PyTuple_SET_ITEM(result, 3, PyLong_FromLong(PQfsize(self->result, num))); + PyTuple_SET_ITEM(result, 4, PyLong_FromLong(PQfmod(self->result, num))); return result; } /* Lists fields info. */ static char source_listinfo__doc__[] = -"listinfo() -- get information for all fields (position, name, type oid)"; + "listinfo() -- get information for all fields (position, name, type oid)"; static PyObject * source_listInfo(sourceObject *self, PyObject *noargs) @@ -710,7 +713,7 @@ source_listInfo(sourceObject *self, PyObject *noargs) /* List fields information for last result. 
*/ static char source_fieldinfo__doc__[] = -"fieldinfo(desc) -- get specified field info (position, name, type oid)"; + "fieldinfo(desc) -- get specified field info (position, name, type oid)"; static PyObject * source_fieldinfo(sourceObject *self, PyObject *desc) @@ -719,9 +722,9 @@ source_fieldinfo(sourceObject *self, PyObject *desc) /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method fieldinfo() needs a string or integer as argument")) == -1) - { + self, desc, + "Method fieldinfo() needs a string or integer as argument")) == + -1) { return NULL; } @@ -731,7 +734,7 @@ source_fieldinfo(sourceObject *self, PyObject *desc) /* Retrieve field value. */ static char source_field__doc__[] = -"field(desc) -- return specified field value"; + "field(desc) -- return specified field value"; static PyObject * source_field(sourceObject *self, PyObject *desc) @@ -740,9 +743,8 @@ source_field(sourceObject *self, PyObject *desc) /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method field() needs a string or integer as argument")) == -1) - { + self, desc, + "Method field() needs a string or integer as argument")) == -1) { return NULL; } @@ -756,78 +758,70 @@ source_dir(connObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssss]", - "pgcnx", "arraysize", "resulttype", "ntuples", "nfields"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssss]", "pgcnx", "arraysize", + "resulttype", "ntuples", "nfields"); return attrs; } /* Source object methods */ static PyMethodDef source_methods[] = { - {"__dir__", (PyCFunction) source_dir, METH_NOARGS, NULL}, - - {"close", (PyCFunction) source_close, - METH_NOARGS, source_close__doc__}, - {"execute", (PyCFunction) source_execute, - METH_O, source_execute__doc__}, - {"oidstatus", (PyCFunction) source_oidstatus, - METH_NOARGS, source_oidstatus__doc__}, - {"fetch", (PyCFunction) source_fetch, - METH_VARARGS, source_fetch__doc__}, - {"movefirst", (PyCFunction) source_movefirst, - METH_NOARGS, source_movefirst__doc__}, - {"movelast", (PyCFunction) source_movelast, - METH_NOARGS, source_movelast__doc__}, - {"movenext", (PyCFunction) source_movenext, - METH_NOARGS, source_movenext__doc__}, - {"moveprev", (PyCFunction) source_moveprev, - METH_NOARGS, source_moveprev__doc__}, - {"putdata", (PyCFunction) source_putdata, - METH_O, source_putdata__doc__}, - {"getdata", (PyCFunction) source_getdata, - METH_VARARGS, source_getdata__doc__}, - {"field", (PyCFunction) source_field, - METH_O, source_field__doc__}, - {"fieldinfo", (PyCFunction) source_fieldinfo, - METH_O, source_fieldinfo__doc__}, - {"listinfo", (PyCFunction) source_listInfo, - METH_NOARGS, source_listinfo__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)source_dir, METH_NOARGS, NULL}, + + {"close", (PyCFunction)source_close, METH_NOARGS, source_close__doc__}, + {"execute", (PyCFunction)source_execute, METH_O, source_execute__doc__}, + {"oidstatus", (PyCFunction)source_oidstatus, METH_NOARGS, + source_oidstatus__doc__}, + {"fetch", (PyCFunction)source_fetch, METH_VARARGS, source_fetch__doc__}, + {"movefirst", (PyCFunction)source_movefirst, METH_NOARGS, + source_movefirst__doc__}, + {"movelast", (PyCFunction)source_movelast, METH_NOARGS, + source_movelast__doc__}, + {"movenext", (PyCFunction)source_movenext, METH_NOARGS, + source_movenext__doc__}, + {"moveprev", (PyCFunction)source_moveprev, 
METH_NOARGS, + source_moveprev__doc__}, + {"putdata", (PyCFunction)source_putdata, METH_O, source_putdata__doc__}, + {"getdata", (PyCFunction)source_getdata, METH_VARARGS, + source_getdata__doc__}, + {"field", (PyCFunction)source_field, METH_O, source_field__doc__}, + {"fieldinfo", (PyCFunction)source_fieldinfo, METH_O, + source_fieldinfo__doc__}, + {"listinfo", (PyCFunction)source_listInfo, METH_NOARGS, + source_listinfo__doc__}, + {NULL, NULL}}; static char source__doc__[] = "PyGreSQL source object"; /* Source type definition */ static PyTypeObject sourceType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pgdb.Source", /* tp_name */ - sizeof(sourceObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pgdb.Source", /* tp_name */ + sizeof(sourceObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) source_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - (setattrfunc) source_setattr, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) source_str, /* tp_str */ - (getattrofunc) source_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - source__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - source_methods, /* tp_methods */ + (destructor)source_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + (setattrfunc)source_setattr, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)source_str, /* tp_str */ + (getattrofunc)source_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + source__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + source_methods, /* tp_methods */ }; diff --git a/tox.ini b/tox.ini index 9ddc3a75..37b3a39d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,docs +envlist = py3{7,8,9,10,11},ruff,cformat,docs [testenv:ruff] basepython = python3.11 @@ -9,6 +9,13 @@ deps = ruff>=0.0.287 commands = ruff setup.py pg.py pgdb.py tests +[testenv:cformat] +basepython = python3.11 +allowlist_externals = + sh +commands = + sh -c "! 
(clang-format --style=file -n *.c 2>&1 | tee /dev/tty | grep format-violations)" + [testenv:docs] basepython = python3.11 deps = From a6c38643170a7037bd93c3290921e632051f1b67 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 3 Sep 2023 15:47:13 +0200 Subject: [PATCH 149/194] Add type hints to the pg module --- docs/contents/pg/db_wrapper.rst | 18 +- docs/contents/pg/query.rst | 4 +- pg.py | 1115 ++++++++++++++++-------------- pgdb.py | 172 ++--- pgmodule.c | 8 +- pyproject.toml | 15 + tests/config.py | 6 +- tests/dbapi20.py | 7 +- tests/test_classic_connection.py | 92 +-- tests/test_classic_dbwrapper.py | 272 ++++---- tests/test_classic_functions.py | 54 +- tests/test_classic_largeobj.py | 3 + tests/test_dbapi20_copy.py | 6 +- tests/test_tutorial.py | 3 +- tox.ini | 6 + 15 files changed, 969 insertions(+), 812 deletions(-) diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 68d33c65..64710456 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -58,7 +58,7 @@ pkey -- return the primary key of a table Return the primary key of a table :param str table: name of table - :returns: Name of the field which is the primary key of the table + :returns: Name of the field that is the primary key of the table :rtype: str :raises KeyError: the table does not have a primary key @@ -67,6 +67,22 @@ returned as strings unless you set the composite flag. Composite primary keys are always represented as tuples. Note that this raises a KeyError if the table does not have a primary key. +pkeys -- return the primary keys of a table +------------------------------------------- + +.. method:: DB.pkeys(table) + + Return the primary keys of a table as a tuple + + :param str table: name of table + :returns: Names of the fields that are the primary keys of the table + :rtype: tuple + :raises KeyError: the table does not have a primary key + +This method returns the primary keys of a table as a tuple, i.e. +single primary keys are also returned as a tuple with one item. +Note that this raises a KeyError if the table does not have a primary key. + get_databases -- get list of databases in the system ---------------------------------------------------- diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 3232c115..fcee193f 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -336,10 +336,10 @@ listfields -- list field names of query result List field names of query result :returns: field names - :rtype: list + :rtype: tuple :raises TypeError: too many parameters -This method returns the list of field names defined for the query result. +This method returns the tuple of field names defined for the query result. The fields are in the same order as the result values. fieldname, fieldnum -- field name/number conversion diff --git a/pg.py b/pg.py index d29cb5c2..11aaf90a 100644 --- a/pg.py +++ b/pg.py @@ -20,9 +20,11 @@ For a DB-API 2 compliant interface use the newer pgdb module. 
""" +from __future__ import annotations + import select import weakref -from collections import OrderedDict, namedtuple +from collections import namedtuple from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -34,7 +36,18 @@ from operator import itemgetter from re import compile as regex from types import MappingProxyType -from typing import Callable, ClassVar, Dict, List, Mapping, Type, Union +from typing import ( + Any, + Callable, + ClassVar, + Generator, + Iterator, + List, + Mapping, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID try: @@ -49,15 +62,16 @@ if os.path.exists(os.path.join(path, libpq))] if sys.version_info >= (3, 8): # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore for path in paths: - with os.add_dll_directory(os.path.abspath(path)): + with add_dll_dir(os.path.abspath(path)): try: - from _pg import version + from _pg import version # type: ignore except ImportError: pass else: del version - e = None + e = None # type: ignore break if paths: libpq = 'compatible ' + libpq @@ -86,6 +100,7 @@ TRANS_INERROR, TRANS_INTRANS, TRANS_UNKNOWN, + Connection, DatabaseError, DataError, Error, @@ -98,6 +113,7 @@ NotSupportedError, OperationalError, ProgrammingError, + Query, Warning, cast_array, cast_hstore, @@ -148,6 +164,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', + 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', @@ -164,21 +181,24 @@ 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', 'set_jsondecode', 'set_query_helpers', 'set_typecast', - 'version', '__version__'] + 'version', '__version__', +] # Auxiliary classes and functions that are independent of a DB connection: -def get_args(func): +def get_args(func: Callable) -> list: return list(signature(func).parameters) # time zones used in Postgres timestamptz output -_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} -def _timezone_as_offset(tz): +def _timezone_as_offset(tz: str) -> str: if tz.startswith(('+', '-')): if len(tz) < 5: return tz + '00' @@ -186,7 +206,7 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _oid_key(table): +def _oid_key(table: str) -> str: """Build oid key from a table name.""" return f'oid({table})' @@ -201,7 +221,7 @@ class Hstore(dict): _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') @classmethod - def _quote(cls, s): + def _quote(cls, s: Any) -> str: if s is None: return 'NULL' if not isinstance(s, str): @@ -213,7 +233,7 @@ def _quote(cls, s): s = f'"{s}"' return s - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -222,12 +242,12 @@ def __str__(self): class Json: """Wrapper class for marking Json values.""" - def __init__(self, obj, encode=None): + def __init__(self, obj: Any, encode: Callable | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode - 
def __str__(self): + def __str__(self) -> str: """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): @@ -241,7 +261,7 @@ class _SimpleTypes(dict): The corresponding Python types and simple names are also mapped. """ - _type_aliases: Mapping[str, List[Union[str, type]]] = MappingProxyType({ + _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ 'bool': [bool], 'bytea': [Bytea], 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', @@ -256,7 +276,7 @@ class _SimpleTypes(dict): }) # noinspection PyMissingConstructor - def __init__(self): + def __init__(self) -> None: """Initialize type mapping.""" for typ, keys in self._type_aliases.items(): keys = [typ, *keys] @@ -265,24 +285,24 @@ def __init__(self): if isinstance(key, str): self[f'_{key}'] = f'{typ}[]' elif not isinstance(key, tuple): - self[List[key]] = f'{typ}[]' + self[List[key]] = f'{typ}[]' # type: ignore @staticmethod - def __missing__(key): + def __missing__(key: str) -> str: """Unmapped types are interpreted as text.""" return 'text' - def get_type_dict(self): + def get_type_dict(self) -> dict[type, str]: """Get a plain dictionary of only the types.""" - return dict((key, typ) for key, typ in self.items() - if not isinstance(key, (str, tuple))) + return {key: typ for key, typ in self.items() + if not isinstance(key, (str, tuple))} _simpletypes = _SimpleTypes() _simple_type_dict = _simpletypes.get_type_dict() -def _quote_if_unqualified(param, name): +def _quote_if_unqualified(param: str, name: int | str) -> str: """Quote parameter representing a qualified name. Puts a quote_ident() call around the given parameter unless @@ -300,7 +320,7 @@ class _ParameterList(list): adapt: Callable - def add(self, value, typ=None): + def add(self, value: Any, typ:Any = None) -> str: """Typecast value with known database type and build parameter list. If this is a literal value, it will be returned as is. 
Otherwise, a @@ -318,29 +338,29 @@ class Literal(str): """Wrapper class for marking literal SQL values.""" -class AttrDict(OrderedDict): +class AttrDict(dict): """Simple read-only ordered dictionary for storing attribute names.""" - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any) -> None: self._read_only = False - OrderedDict.__init__(self, *args, **kw) + super().__init__(*args, **kw) self._read_only = True error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error + self.clear = self.update = error # type: ignore + self.pop = self.setdefault = self.popitem = error # type: ignore - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if self._read_only: self._read_only_error() - OrderedDict.__setitem__(self, key, value) + super().__setitem__(key, value) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: if self._read_only: self._read_only_error() - OrderedDict.__delitem__(self, key) + super().__delitem__(key) @staticmethod - def _read_only_error(*args, **kw): + def _read_only_error(*_args: Any, **_kw: Any) -> Any: raise TypeError('This object is read-only') @@ -357,12 +377,12 @@ class Adapter: _re_record_quote = regex(r'[(,"\\]') _re_array_escape = _re_record_escape = regex(r'(["\\])') - def __init__(self, db): + def __init__(self, db: DB): """Initialize the adapter object with the given connection.""" self.db = weakref.proxy(db) @classmethod - def _adapt_bool(cls, v): + def _adapt_bool(cls, v: Any) -> str | None: """Adapt a boolean parameter.""" if isinstance(v, str): if not v: @@ -371,7 +391,7 @@ def _adapt_bool(cls, v): return 't' if v else 'f' @classmethod - def _adapt_date(cls, v): + def _adapt_date(cls, v: Any) -> Any: """Adapt a date parameter.""" if not v: return None @@ -380,7 +400,7 @@ def _adapt_date(cls, v): return v @staticmethod - def _adapt_num(v): + def _adapt_num(v: Any) -> Any: """Adapt a numeric parameter.""" if not v and v != 0: return None @@ -388,11 +408,11 @@ def _adapt_num(v): _adapt_int = _adapt_float = _adapt_money = _adapt_num - def _adapt_bytea(self, v): + def _adapt_bytea(self, v: Any) -> str: """Adapt a bytea parameter.""" return self.db.escape_bytea(v) - def _adapt_json(self, v): + def _adapt_json(self, v: Any) -> str | None: """Adapt a json parameter.""" if not v: return None @@ -402,7 +422,7 @@ def _adapt_json(self, v): return str(v) return self.db.encode_json(v) - def _adapt_hstore(self, v): + def _adapt_hstore(self, v: Any) -> str | None: """Adapt a hstore parameter.""" if not v: return None @@ -414,7 +434,7 @@ def _adapt_hstore(self, v): return str(Hstore(v)) raise TypeError(f'Hstore parameter {v} has wrong type') - def _adapt_uuid(self, v): + def _adapt_uuid(self, v: Any) -> str | None: """Adapt a UUID parameter.""" if not v: return None @@ -423,7 +443,7 @@ def _adapt_uuid(self, v): return str(v) @classmethod - def _adapt_text_array(cls, v): + def _adapt_text_array(cls, v: Any) -> str: """Adapt a text type array parameter.""" if isinstance(v, list): adapt = cls._adapt_text_array @@ -441,7 +461,7 @@ def _adapt_text_array(cls, v): _adapt_date_array = _adapt_text_array @classmethod - def _adapt_bool_array(cls, v): + def _adapt_bool_array(cls, v: Any) -> str: """Adapt a boolean array parameter.""" if isinstance(v, list): adapt = cls._adapt_bool_array @@ -455,7 +475,7 @@ def _adapt_bool_array(cls, v): return 't' if v else 'f' @classmethod - def _adapt_num_array(cls, v): + def _adapt_num_array(cls, v: Any) -> str: 
"""Adapt a numeric array parameter.""" if isinstance(v, list): adapt = cls._adapt_num_array @@ -467,7 +487,7 @@ def _adapt_num_array(cls, v): _adapt_int_array = _adapt_float_array = _adapt_money_array = \ _adapt_num_array - def _adapt_bytea_array(self, v): + def _adapt_bytea_array(self, v: Any) -> bytes: """Adapt a bytea array parameter.""" if isinstance(v, list): return b'{' + b','.join( @@ -476,7 +496,7 @@ def _adapt_bytea_array(self, v): return b'null' return self.db.escape_bytea(v).replace(b'\\', b'\\\\') - def _adapt_json_array(self, v): + def _adapt_json_array(self, v: Any) -> str: """Adapt a json array parameter.""" if isinstance(v, list): adapt = self._adapt_json_array @@ -490,7 +510,7 @@ def _adapt_json_array(self, v): v = f'"{v}"' return v - def _adapt_record(self, v, typ): + def _adapt_record(self, v: Any, typ: Any) -> str: """Adapt a record parameter with given type.""" typ = self.get_attnames(typ).values() if len(typ) != len(v): @@ -516,7 +536,7 @@ def _adapt_record(self, v, typ): v = ','.join(value) return f'({v})' - def adapt(self, value, typ=None): + def adapt(self, value: Any, typ: Any = None) -> str: """Adapt a value with known database type.""" if value is not None and not isinstance(value, Literal): if typ: @@ -541,14 +561,14 @@ def adapt(self, value, typ=None): return value @staticmethod - def simple_type(name): + def simple_type(name: str) -> DbType: """Create a simple database type with given attribute names.""" typ = DbType(name) typ.simple = name return typ @staticmethod - def get_simple_name(typ): + def get_simple_name(typ: Any) -> str: """Get the simple name of a database type.""" if isinstance(typ, DbType): # noinspection PyUnresolvedReferences @@ -556,14 +576,14 @@ def get_simple_name(typ): return _simpletypes[typ] @staticmethod - def get_attnames(typ): + def get_attnames(typ: Any) -> dict[str, dict[str, str]]: """Get the attribute names of a composite database type.""" if isinstance(typ, DbType): return typ.attnames return {} @classmethod - def guess_simple_type(cls, value): + def guess_simple_type(cls, value: Any) -> str | None: """Try to guess which database type the given value has.""" # optimize for most frequent types try: @@ -597,16 +617,17 @@ def guess_simple_type(cls, value): guess = cls.guess_simple_type # noinspection PyUnusedLocal - def get_attnames(self): - return AttrDict((str(n + 1), simple_type(guess(v))) + def get_attnames(self: DbType) -> AttrDict: + return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) for n, v in enumerate(value)) typ = simple_type('record') typ._get_attnames = get_attnames return typ + return None @classmethod - def guess_simple_base_type(cls, value): + def guess_simple_base_type(cls, value: Any) -> str | None: """Try to guess the base type of a given array.""" for v in value: if isinstance(v, list): @@ -615,8 +636,9 @@ def guess_simple_base_type(cls, value): typ = cls.guess_simple_type(v) if typ: return typ + return None - def adapt_inline(self, value, nested=False): + def adapt_inline(self, value: Any, nested: bool=False) -> Any: """Adapt a value that is put into the SQL and needs to be quoted.""" if value is None: return 'NULL' @@ -661,7 +683,7 @@ def adapt_inline(self, value, nested=False): value = self.adapt_inline(value) return value - def parameter_list(self): + def parameter_list(self) -> _ParameterList: """Return a parameter list for parameters with known database types. 
The list has an add(value, typ) method that will build up the @@ -671,7 +693,11 @@ def parameter_list(self): params.adapt = self.adapt return params - def format_query(self, command, values=None, types=None, inline=False): + def format_query(self, command: str, + values: list | tuple | dict | None = None, + types: list | tuple | dict | None = None, + inline: bool=False + ) -> tuple[str, _ParameterList]: """Format a database query using the given values and types. The optional types describe the values and must be passed as a list, @@ -681,15 +707,15 @@ def format_query(self, command, values=None, types=None, inline=False): If inline is set to True, then parameters will be passed inline together with the query string. """ + params = self.parameter_list() if not values: - return command, [] + return command, params if inline and types: raise ValueError('Typed parameters must be sent separately') - params = self.parameter_list() if isinstance(values, (list, tuple)): if inline: adapt = self.adapt_inline - literals = [adapt(value) for value in values] + seq_literals = [adapt(value) for value in values] else: add = params.add if types: @@ -698,52 +724,51 @@ def format_query(self, command, values=None, types=None, inline=False): if (not isinstance(types, (list, tuple)) or len(types) != len(values)): raise TypeError('The values and types do not match') - literals = [add(value, typ) - for value, typ in zip(values, types)] + seq_literals = [add(value, typ) + for value, typ in zip(values, types)] else: - literals = [add(value) for value in values] - command %= tuple(literals) + seq_literals = [add(value) for value in values] + command %= tuple(seq_literals) elif isinstance(values, dict): # we want to allow extra keys in the dictionary, # so we first must find the values actually used in the command used_values = {} - literals = dict.fromkeys(values, '') + map_literals = dict.fromkeys(values, '') for key in values: - del literals[key] + del map_literals[key] try: - command % literals + command % map_literals except KeyError: - used_values[key] = values[key] - literals[key] = '' - values = used_values + used_values[key] = values[key] # pyright: ignore + map_literals[key] = '' if inline: adapt = self.adapt_inline - literals = {key: adapt(value) - for key, value in values.items()} + map_literals = {key: adapt(value) + for key, value in used_values.items()} else: add = params.add if types: if not isinstance(types, dict): raise TypeError('The values and types do not match') - literals = {key: add(values[key], types.get(key)) - for key in sorted(values)} + map_literals = {key: add(used_values[key], types.get(key)) + for key in sorted(used_values)} else: - literals = {key: add(values[key]) - for key in sorted(values)} - command %= literals + map_literals = {key: add(used_values[key]) + for key in sorted(used_values)} + command %= map_literals else: raise TypeError('The values must be passed as tuple, list or dict') return command, params -def cast_bool(value): +def cast_bool(value: str) -> Any: """Cast a boolean value.""" if not get_bool(): return value return value[0] == 't' -def cast_json(value): +def cast_json(value: str) -> Any: """Cast a JSON value.""" cast = get_jsondecode() if not cast: @@ -751,12 +776,12 @@ def cast_json(value): return cast(value) -def cast_num(value): +def cast_num(value: str) -> Any: """Cast a numeric value.""" return (get_decimal() or float)(value) -def cast_money(value): +def cast_money(value: str) -> Any: """Cast a money value.""" point = get_decimal_point() if not point: @@ 
-768,12 +793,12 @@ def cast_money(value): return (get_decimal() or float)(value) -def cast_int2vector(value): +def cast_int2vector(value: str) -> list[int]: """Cast an int2vector value.""" return [int(v) for v in value.split()] -def cast_date(value, connection): +def cast_date(value: str, connection: DB) -> Any: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. The @@ -783,93 +808,93 @@ def cast_date(value, connection): return date.min if value == 'infinity': return date.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return date.min - value = value[0] + value = values[0] if len(value) > 10: return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() + format = connection.date_format() + return datetime.strptime(value, format).date() -def cast_time(value): +def cast_time(value: str) -> Any: """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, format).time() _re_timezone = regex('(.*)([+-].*)') -def cast_timetz(value): +def cast_timetz(value: str) -> Any: """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() + m = _re_timezone.match(value) + if m: + value, tz = m.groups() else: tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() + format += '%z' + return datetime.strptime(value, format).timetz() -def cast_timestamp(value, connection): +def cast_timestamp(value: str, connection: DB) -> Any: """Cast a timestamp value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] else: - if len(value[0]) > 10: + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value, connection): +def cast_timestamptz(value: str, connection: DB) -> Any: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + 
if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() else: tz = '+0000' else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) _re_interval_sql_standard = regex( @@ -900,37 +925,37 @@ def cast_timestamptz(value, connection): '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') -def cast_interval(value): +def cast_interval(value: str) -> timedelta: """Cast an interval value.""" # The output format depends on the server setting IntervalStyle, but it's # not necessary to consult this setting to parse it. It's faster to just # check all possible formats, and there is no ambiguity here. m = _re_interval_iso_8601.match(value) if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = -secs usecs = -usecs else: m = _re_interval_postgres_verbose.match(value) if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = - secs usecs = -usecs else: m = _re_interval_postgres.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if hours_ago: hours = -hours mins = -mins @@ -939,11 +964,11 @@ def cast_interval(value): else: m = _re_interval_sql_standard.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if years_ago: years = -years mons = -mons @@ -973,7 +998,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults: ClassVar[Dict[str, Type]] = { + defaults: 
ClassVar[dict[str, Callable]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -985,11 +1010,11 @@ class Typecasts(dict): 'time': cast_time, 'timetz': cast_timetz, 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, 'int2vector': cast_int2vector, 'uuid': UUID, - 'anyarray': cast_array, 'record': cast_record} + 'anyarray': cast_array, 'record': cast_record} # pyright: ignore - connection = None # will be set in a connection specific instance + connection: DB | None = None # set in a connection specific instance - def __missing__(self, typ): + def __missing__(self, typ: Any) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -997,7 +1022,7 @@ def __missing__(self, typ): """ if not isinstance(typ, str): raise TypeError(f'Invalid type: {typ}') - cast = self.defaults.get(typ) + cast: Callable | None = self.defaults.get(typ) if cast: # store default for faster access cast = self._add_connection(cast) @@ -1016,7 +1041,7 @@ def __missing__(self, typ): return cast @staticmethod - def _needs_connection(func): + def _needs_connection(func: Callable) -> bool: """Check if a typecast function needs a connection argument.""" try: args = get_args(func) @@ -1025,17 +1050,17 @@ def _needs_connection(func): else: return 'connection' in args[1:] - def _add_connection(self, cast): + def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" if not self.connection or not self._needs_connection(cast): return cast return partial(cast, connection=self.connection) - def get(self, typ, default=None): + def get(self, typ: Any, default: Any = None) -> Any: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ, cast): + def set(self, typ: Any, cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1050,7 +1075,7 @@ def set(self, typ, cast): self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ=None): + def reset(self, typ: Any = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. @@ -1064,12 +1089,12 @@ def reset(self, typ=None): self.pop(t, None) @classmethod - def get_default(cls, typ): + def get_default(cls, typ: Any) -> Any: """Get the default typecast function for the given database type.""" return cls.defaults.get(typ) @classmethod - def set_default(cls, typ, cast): + def set_default(cls, typ: Any, cast: Callable | None) -> None: """Set a default typecast function for the given database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1086,46 +1111,47 @@ def set_default(cls, typ, cast): defaults.pop(f'_{t}', None) # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_attnames(self, typ): + def get_attnames(self, typ: Any) -> AttrDict: """Return the fields for the given record type. This method will be replaced with the get_attnames() method of DbTypes. """ - return {} + return AttrDict() # noinspection PyMethodMayBeStatic - def dateformat(self): + def dateformat(self) -> str: """Return the current date format. This method will be replaced with the dateformat() method of DbTypes. 
""" return '%Y-%m-%d' - def create_array_cast(self, basecast): + def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v): + def cast(v: Any) -> Callable: return cast_array(v, basecast) return cast - def create_record_cast(self, name, fields, casts): + def create_record_cast(self, name: str, fields: AttrDict, + casts: list[Callable]) -> Callable: """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] - record = namedtuple(name, fields) + record = namedtuple(name, fields) # type: ignore - def cast(v): + def cast(v: Any) -> record: # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast -def get_typecast(typ): +def get_typecast(typ: Any) -> Callable | None: """Get the global typecast function for the given database type(s).""" return Typecasts.get_default(typ) -def set_typecast(typ, cast): +def set_typecast(typ: Any, cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -1161,10 +1187,10 @@ class DbType(str): delim: str relid: int - _get_attnames: Callable + _get_attnames: Callable[[DbType], AttrDict] @property - def attnames(self): + def attnames(self) -> AttrDict: """Get names and types of the fields of a composite type.""" # noinspection PyUnresolvedReferences return self._get_attnames(self) @@ -1180,13 +1206,13 @@ class DbTypes(dict): _num_types = frozenset('int float num money int2 int4 int8' ' float4 float8 numeric money'.split()) - def __init__(self, db): + def __init__(self, db: DB) -> None: """Initialize type cache for connection.""" super().__init__() self._db = weakref.proxy(db) self._regtypes = False self._typecasts = Typecasts() - self._typecasts.get_attnames = self.get_attnames + self._typecasts.get_attnames = self.get_attnames # type: ignore self._typecasts.connection = self._db self._query_pg_type = ( "SELECT oid, typname, oid::pg_catalog.regtype," @@ -1194,8 +1220,9 @@ def __init__(self, db): " FROM pg_catalog.pg_type" " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") - def add(self, oid, pgtype, regtype, - typlen, typtype, category, delim, relid): + def add(self, oid: int, pgtype: str, regtype: str, + typlen: int, typtype: str, category: str, delim: str, relid: int + ) -> DbType: """Create a PostgreSQL type name with additional info.""" if oid in self: return self[oid] @@ -1210,14 +1237,14 @@ def add(self, oid, pgtype, regtype, typ.category = category typ.delim = delim typ.relid = relid - typ._get_attnames = self.get_attnames + typ._get_attnames = self.get_attnames # type: ignore return typ - def __missing__(self, key): + def __missing__(self, key: int | str) -> DbType: """Get the type info from the database if it is not cached.""" try: - q = self._query_pg_type.format(_quote_if_unqualified('$1', key)) - res = self._db.query(q, (key,)).getresult() + cmd = self._query_pg_type.format(_quote_if_unqualified('$1', key)) + res = self._db.query(cmd, (key,)).getresult() except ProgrammingError: res = None if not res: @@ -1227,14 +1254,14 @@ def __missing__(self, key): self[typ.oid] = self[typ.pgtype] = typ return typ - def get(self, key, default=None): + def get(self, key: int | str, default: Any = None) -> Any: """Get the type even if it is not cached.""" try: return self[key] except KeyError: return default - def get_attnames(self, typ): + def get_attnames(self, typ: Any) -> AttrDict | 
None: """Get names and types of the fields of a composite type.""" if not isinstance(typ, DbType): typ = self.get(typ) @@ -1244,19 +1271,19 @@ def get_attnames(self, typ): return None return self._db.get_attnames(typ.relid, with_oid=False) - def get_typecast(self, typ): + def get_typecast(self, typ: Any) -> Callable: """Get the typecast function for the given database type.""" return self._typecasts.get(typ) - def set_typecast(self, typ, cast): + def set_typecast(self, typ: Any, cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ=None): + def reset_typecast(self, typ: Any = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value, typ): + def typecast(self, value: Any, typ: Any) -> Callable | None: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary @@ -1272,25 +1299,22 @@ def typecast(self, value, typ): return cast(value) -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - - # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. # noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) -def _row_factory(names): +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: """Get a namedtuple factory for row results with the given names.""" try: - return namedtuple('Row', names, rename=True)._make + return namedtuple('Row', names, rename=True)._make # type: ignore except ValueError: # there is still a problem with the field names names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make + return namedtuple('Row', names)._make # type: ignore -def set_row_factory_size(maxsize): +def set_row_factory_size(maxsize: int | None) -> None: """Change the size of the namedtuple factory cache. If maxsize is set to None, the cache can grow without bound. @@ -1302,26 +1326,26 @@ def set_row_factory_size(maxsize): # Helper functions used by the query object -def _dictiter(q): +def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: """Get query result as an iterator of dictionaries.""" - fields = q.listfields() + fields: tuple[str, ...] = q.listfields() for r in q: yield dict(zip(fields, r)) -def _namediter(q): +def _namediter(q: Query) -> Generator[NamedTuple, None, None]: """Get query result as an iterator of named tuples.""" row = _row_factory(q.listfields()) for r in q: yield row(r) -def _namednext(q): +def _namednext(q: Query) -> NamedTuple: """Get next row from query result as a named tuple.""" return _row_factory(q.listfields())(next(q)) -def _scalariter(q): +def _scalariter(q: Query) -> Generator[Any, None, None]: """Get query result as an iterator of scalar values.""" for r in q: yield r[0] @@ -1330,36 +1354,41 @@ def _scalariter(q): class _MemoryQuery: """Class that embodies a given query result.""" - def __init__(self, result, fields): + result: Any + fields: tuple[str, ...] 
+
+    def __init__(self, result: Any, fields: Sequence[str]) -> None:
         """Create query from given result rows and field names."""
         self.result = result
         self.fields = tuple(fields)
 
-    def listfields(self):
+    def listfields(self) -> tuple[str, ...]:
         """Return the stored field names of this query."""
         return self.fields
 
-    def getresult(self):
+    def getresult(self) -> Any:
         """Return the stored result of this query."""
         return self.result
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         return iter(self.result)
 
 
-def _db_error(msg, cls=DatabaseError):
+E = TypeVar('E', bound=DatabaseError)
+
+def _db_error(msg: str, cls: type[E] = DatabaseError) -> E:
     """Return DatabaseError with empty sqlstate attribute."""
     error = cls(msg)
     error.sqlstate = None
     return error
 
 
-def _int_error(msg):
+def _int_error(msg: str) -> InternalError:
     """Return InternalError."""
     return _db_error(msg, InternalError)
 
 
-def _prg_error(msg):
+def _prg_error(msg: str) -> ProgrammingError:
     """Return ProgrammingError."""
     return _db_error(msg, ProgrammingError)
 
@@ -1376,8 +1405,10 @@ def _prg_error(msg):
 class NotificationHandler:
     """A PostgreSQL client-side asynchronous notification handler."""
 
-    def __init__(self, db, event, callback=None,
-                 arg_dict=None, timeout=None, stop_event=None):
+    def __init__(self, db: DB, event: str, callback: Callable,
+                 arg_dict: dict | None = None,
+                 timeout: int | float | None = None,
+                 stop_event: str | None = None):
         """Initialize the notification handler.
 
         You must pass a PyGreSQL database connection, the name of an
@@ -1395,7 +1426,7 @@ def __init__(self, db, event, callback=None,
         the handler to stop listening as stop_event. By default, it will
         be the event name prefixed with 'stop_'.
         """
-        self.db = db
+        self.db: DB | None = db
         self.event = event
         self.stop_event = stop_event or f'stop_{event}'
         self.listening = False
@@ -1405,32 +1436,35 @@ def __init__(self, db, event, callback=None,
         self.arg_dict = arg_dict
         self.timeout = timeout
 
-    def __del__(self):
+    def __del__(self) -> None:
         """Delete the notification handler."""
         self.unlisten()
 
-    def close(self):
+    def close(self) -> None:
         """Stop listening and close the connection."""
         if self.db:
             self.unlisten()
             self.db.close()
             self.db = None
 
-    def listen(self):
+    def listen(self) -> None:
         """Start listening for the event and the stop event."""
-        if not self.listening:
-            self.db.query(f'listen "{self.event}"')
-            self.db.query(f'listen "{self.stop_event}"')
+        db = self.db
+        if db and not self.listening:
+            db.query(f'listen "{self.event}"')
+            db.query(f'listen "{self.stop_event}"')
             self.listening = True
 
-    def unlisten(self):
+    def unlisten(self) -> None:
         """Stop listening for the event and the stop event."""
-        if self.listening:
-            self.db.query(f'unlisten "{self.event}"')
-            self.db.query(f'unlisten "{self.stop_event}"')
+        db = self.db
+        if db and self.listening:
+            db.query(f'unlisten "{self.event}"')
+            db.query(f'unlisten "{self.stop_event}"')
             self.listening = False
 
-    def notify(self, db=None, stop=False, payload=None):
+    def notify(self, db: DB | None = None, stop: bool = False,
+               payload: str | None = None) -> None:
         """Generate a notification.
 
         Optionally, you can pass a payload with the notification.
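As a quick orientation for the notification handler hunks above and below,
here is a minimal usage sketch; the database name, event name and callback
are illustrative assumptions, not part of the patch::

    import pg

    db = pg.DB(dbname='mydb')  # dedicated connection used for listening

    def on_event(arg_dict):
        # the handler loop invokes the callback with its arg_dict,
        # which also carries the event name and any payload sent
        print('received notification:', arg_dict)

    handler = pg.NotificationHandler(db, 'my_event', on_event)
    handler()  # blocks and dispatches until the stop event arrives

Another session can then fire the callback with db.query('notify my_event')
and end the loop by notifying the generated stop event ('stop_my_event').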
@@ -1445,13 +1479,15 @@ def notify(self, db=None, stop=False, payload=None):
         if self.listening:
             if not db:
                 db = self.db
+            if not db:
+                return
             event = self.stop_event if stop else self.event
-            q = f'notify "{event}"'
+            cmd = f'notify "{event}"'
             if payload:
-                q += f", '{payload}'"
-            return db.query(q)
+                cmd += f", '{payload}'"
+            return db.query(cmd)
 
-    def __call__(self):
+    def __call__(self) -> None:
         """Invoke the notification handler.
 
         The handler is a loop that listens for notifications on the event
@@ -1469,14 +1505,15 @@ def __call__(self):
         Note: If you run this loop in another thread, don't use the same
         database connection for database operations in the main thread.
         """
+        if not self.db:
+            return
         self.listen()
         poll = self.timeout == 0
-        if not poll:
-            rlist = [self.db.fileno()]
-        while self.listening:
+        rlist = [] if poll else [self.db.fileno()]
+        while self.db and self.listening:
             # noinspection PyUnboundLocalVariable
             if poll or select.select(rlist, [], [], self.timeout)[0]:
-                while self.listening:
+                while self.db and self.listening:
                     notice = self.db.getnotify()
                     if not notice:  # no more messages
                         break
@@ -1503,9 +1540,9 @@ def __call__(self):
 class DB:
     """Wrapper class for the _pg connection type."""
 
-    db = None  # invalid fallback for underlying connection
+    db: Connection | None = None  # invalid fallback for underlying connection
 
-    def __init__(self, *args, **kw):
+    def __init__(self, *args: Any, **kw: Any) -> None:
         """Create a new connection.
 
         You can pass either the connection parameters or an existing
@@ -1535,10 +1572,10 @@ def __init__(self, *args, **kw):
         self.db = db
         self.dbname = db.db
         self._regtypes = False
-        self._attnames = {}
-        self._generated = {}
-        self._pkeys = {}
-        self._privileges = {}
+        self._attnames: dict[str, AttrDict] = {}
+        self._generated: dict[str, frozenset[str]] = {}
+        self._pkeys: dict[str, str | tuple[str, ...]] = {}
+        self._privileges: dict[tuple[str, str], bool] = {}
         self.adapter = Adapter(self)
         self.dbtypes = DbTypes(self)
         self._query_attnames = (
@@ -1566,9 +1603,9 @@ def __init__(self, *args, **kw):
         # * to a file object to write debug statements or
         # * to a callable object which takes a string argument
         # * to any other true value to just print debug statements
-        self.debug = None
+        self.debug: Any = None
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Get the specified attribute of the connection."""
         # All undefined members are same as in underlying connection:
         if self.db:
@@ -1576,7 +1613,7 @@ def __getattr__(self, name):
         else:
             raise _int_error('Connection is not valid')
 
-    def __dir__(self):
+    def __dir__(self) -> list[str]:
         """List all attributes of the connection."""
         # Custom dir function including the attributes of the connection:
         attrs = set(self.__class__.__dict__)
@@ -1586,19 +1623,20 @@ def __dir__(self):
 
     # Context manager methods
 
-    def __enter__(self):
+    def __enter__(self) -> DB:
         """Enter the runtime context. This will start a transaction."""
         self.begin()
         return self
 
-    def __exit__(self, et, ev, tb):
+    def __exit__(self, et: type[BaseException] | None,
+                 ev: BaseException | None, tb: Any) -> None:
         """Exit the runtime context.
This will end the transaction.""" if et is None and ev is None and tb is None: self.commit() else: self.rollback() - def __del__(self): + def __del__(self) -> None: """Delete the connection.""" try: db = self.db @@ -1613,7 +1651,7 @@ def __del__(self): # Auxiliary methods - def _do_debug(self, *args): + def _do_debug(self, *args: Any) -> None: """Print a debug message.""" if self.debug: s = '\n'.join(str(arg) for arg in args) @@ -1627,7 +1665,7 @@ def _do_debug(self, *args): else: print(s) - def _escape_qualified_name(self, s): + def _escape_qualified_name(self, s: str) -> str: """Escape a qualified name. Escapes the name for use as an SQL identifier, unless the @@ -1640,15 +1678,23 @@ def _escape_qualified_name(self, s): return s @staticmethod - def _make_bool(d): + def _make_bool(d: Any) -> bool | str: """Get boolean value corresponding to d.""" return bool(d) if get_bool() else ('t' if d else 'f') @staticmethod - def _list_params(params): + def _list_params(params: Sequence) -> str: """Create a human readable parameter list.""" return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) + @property + def _valid_db(self) -> Connection: + """Get underlying connection and make sure it is not closed.""" + db = self.db + if not db: + raise _int_error('Connection already closed') + return db + # Public methods # escape_string and escape_bytea exist as methods, @@ -1656,46 +1702,38 @@ def _list_params(params): unescape_bytea = staticmethod(unescape_bytea) @staticmethod - def decode_json(s): + def decode_json(s: str) -> Any: """Decode a JSON string coming from the database.""" return (get_jsondecode() or jsondecode)(s) @staticmethod - def encode_json(d): + def encode_json(d: Any) -> str: """Encode a JSON string for use within SQL.""" return jsonencode(d) - def close(self): + def close(self) -> None: """Close the database connection.""" # Wraps shared library function so we can track state. - db = self.db - if db: - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - db.close() - self.db = None - else: - raise _int_error('Connection already closed') + db = self._valid_db + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + db.close() + self.db = None - def reset(self): + def reset(self) -> None: """Reset connection with current parameters. All derived queries and large objects derived from this connection will not be usable after this call. + """ + self._valid_db.reset() - """ - if self.db: - self.db.reset() - else: - raise _int_error('Connection already closed') - - def reopen(self): + def reopen(self) -> None: """Reopen connection to the database. Used in case we need another connection to the same database. Note that we can still reopen a database that we have closed. - """ # There is no such shared library function. 
if self._closeable: @@ -1708,7 +1746,7 @@ def reopen(self): else: self.db = self._db_args - def begin(self, mode=None): + def begin(self, mode: str | None = None) -> None: """Begin a transaction.""" qstr = 'BEGIN' if mode: @@ -1717,13 +1755,13 @@ def begin(self, mode=None): start = begin - def commit(self): + def commit(self) -> None: """Commit the current transaction.""" return self.query('COMMIT') end = commit - def rollback(self, name=None): + def rollback(self, name: str | None = None) -> None: """Roll back the current transaction.""" qstr = 'ROLLBACK' if name: @@ -1732,15 +1770,18 @@ def rollback(self, name=None): abort = rollback - def savepoint(self, name): + def savepoint(self, name: str) -> None: """Define a new savepoint within the current transaction.""" return self.query('SAVEPOINT ' + name) - def release(self, name): + def release(self, name: str) -> None: """Destroy a previously defined savepoint.""" return self.query('RELEASE ' + name) - def get_parameter(self, parameter): + def get_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any] + ) -> str | list[str] | dict[str, str]: """Get the value of a run-time parameter. If the parameter is a string, the return value will also be a string @@ -1757,6 +1798,7 @@ def get_parameter(self, parameter): By passing the special name 'all' as the parameter, you can get a dict of all existing configuration parameters. """ + values: Any if isinstance(parameter, str): parameter = [parameter] values = None @@ -1771,25 +1813,26 @@ def get_parameter(self, parameter): 'The parameter must be a string, list, set or dict') if not parameter: raise TypeError('No parameter has been specified') - params = {} if isinstance(values, dict) else [] - for key in parameter: - param = key.strip().lower() if isinstance( - key, (bytes, str)) else None + query = self._valid_db.query + params: Any = {} if isinstance(values, dict) else [] + for param_key in parameter: + param = param_key.strip().lower() if isinstance( + param_key, (bytes, str)) else None if not param: raise TypeError('Invalid parameter') if param == 'all': - q = 'SHOW ALL' - values = self.db.query(q).getresult() + cmd = 'SHOW ALL' + values = query(cmd).getresult() values = {value[0]: value[1] for value in values} break - if isinstance(values, dict): - params[param] = key + if isinstance(params, dict): + params[param] = param_key else: params.append(param) else: for param in params: - q = f'SHOW {param}' - value = self.db.query(q).singlescalar() + cmd = f'SHOW {param}' + value = query(cmd).singlescalar() if values is None: values = value elif isinstance(values, list): @@ -1798,7 +1841,12 @@ def get_parameter(self, parameter): values[params[param]] = value return values - def set_parameter(self, parameter, value=None, local=False): + def set_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any], + value: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str]| None = None, + local: bool = False) -> None: """Set the value of a run-time parameter. 
If the parameter and the value are strings, the run-time parameter @@ -1833,7 +1881,7 @@ def set_parameter(self, parameter, value=None, local=False): if isinstance(value, (list, tuple, set, frozenset)): value = set(value) if len(value) == 1: - value = value.pop() + value = next(iter(value)) if not (value is None or isinstance(value, str)): raise ValueError( 'A single value must be specified' @@ -1849,30 +1897,28 @@ def set_parameter(self, parameter, value=None, local=False): 'The parameter must be a string, list, set or dict') if not parameter: raise TypeError('No parameter has been specified') - params = {} - for key, value in parameter.items(): - param = key.strip().lower() if isinstance( - key, str) else None + params: dict[str, str | None] = {} + for param, param_value in parameter.items(): + param = param.strip().lower() if isinstance(param, str) else None if not param: raise TypeError('Invalid parameter') if param == 'all': - if value is not None: + if param_value is not None: raise ValueError( - 'A value must ot be specified' + 'A value must not be specified' " when parameter is 'all'") params = {'all': None} break - params[param] = value - local = ' LOCAL' if local else '' - for param, value in params.items(): - if value is None: - q = f'RESET{local} {param}' - else: - q = f'SET{local} {param} TO {value}' - self._do_debug(q) - self.db.query(q) - - def query(self, command, *args): + params[param] = param_value + local_clause = ' LOCAL' if local else '' + for param, param_value in params.items(): + cmd = (f'RESET{local_clause} {param}' + if param_value is None else + f'SET{local_clause} {param} TO {param_value}') + self._do_debug(cmd) + self._valid_db.query(cmd) + + def query(self, command: str, *args: Any) -> Query: """Execute a SQL command string. This method simply sends a SQL query to the database. If the query is @@ -1892,16 +1938,17 @@ def query(self, command, *args): values can also be given as a single list or tuple argument. """ # Wraps shared library function for debugging. - if not self.db: - raise _int_error('Connection is not valid') + db = self._valid_db if args: self._do_debug(command, args) - return self.db.query(command, args) + return db.query(command, args) self._do_debug(command) - return self.db.query(command) + return db.query(command) - def query_formatted(self, command, - parameters=None, types=None, inline=False): + def query_formatted(self, command: str, + parameters: tuple | list | dict | None = None, + types: tuple | list | dict | None = None, + inline: bool =False) -> Query: """Execute a formatted SQL command string. Similar to query, but using Python format placeholders of the form @@ -1916,24 +1963,23 @@ def query_formatted(self, command, return self.query(*self.adapter.format_query( command, parameters, types, inline)) - def query_prepared(self, name, *args): + def query_prepared(self, name: str, *args: Any) -> Query: """Execute a prepared SQL statement. This works like the query() method, except that instead of passing the SQL command, you pass the name of a prepared statement. If you pass an empty name, the unnamed statement will be executed. 
""" - if not self.db: - raise _int_error('Connection is not valid') if name is None: name = '' + db = self._valid_db if args: self._do_debug('EXECUTE', name, args) - return self.db.query_prepared(name, args) + return db.query_prepared(name, args) self._do_debug('EXECUTE', name) - return self.db.query_prepared(name) + return db.query_prepared(name) - def prepare(self, name, command): + def prepare(self, name: str, command: str) -> Query: """Create a prepared SQL statement. This creates a prepared statement for the given command with the @@ -1946,14 +1992,12 @@ def prepare(self, name, command): named queries, since unnamed queries have a limited lifetime and can be automatically replaced or destroyed by various operations. """ - if not self.db: - raise _int_error('Connection is not valid') if name is None: name = '' self._do_debug('prepare', name, command) - return self.db.prepare(name, command) + return self._valid_db.prepare(name, command) - def describe_prepared(self, name=None): + def describe_prepared(self, name: str | None = None) -> Query: """Describe a prepared SQL statement. This method returns a Query object describing the result columns of @@ -1962,9 +2006,9 @@ def describe_prepared(self, name=None): """ if name is None: name = '' - return self.db.describe_prepared(name) + return self._valid_db.describe_prepared(name) - def delete_prepared(self, name=None): + def delete_prepared(self, name: str | None = None) -> Query: """Delete a prepared SQL statement. This deallocates a previously prepared SQL statement with the given @@ -1974,12 +2018,13 @@ def delete_prepared(self, name=None): """ if not name: name = 'ALL' - q = f"DEALLOCATE {name}" - self._do_debug(q) - return self.db.query(q) + cmd = f"DEALLOCATE {name}" + self._do_debug(cmd) + return self._valid_db.query(cmd) - def pkey(self, table, composite=False, flush=False): - """Get or set the primary key of a table. + def pkey(self, table: str, composite: bool = False, flush: bool = False + ) -> str | tuple[str, ...]: + """Get the primary key of a table. Single primary keys are returned as strings unless you set the composite flag. 
Composite primary keys are always @@ -1997,26 +2042,26 @@ def pkey(self, table, composite=False, flush=False): try: # cache lookup pkey = pkeys[table] except KeyError as e: # cache miss, check the database - q = ("SELECT" # noqa: S608 - " a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum").format( - _quote_if_unqualified('$1', table)) - pkey = self.db.query(q, (table,)).getresult() + cmd = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + _quote_if_unqualified('$1', table)) + pkey = self._valid_db.query(cmd, (table,)).getresult() if not pkey: raise KeyError(f'Table {table} has no primary key') from e # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: indkey = pkey[0][2] - pkey = sorted(pkey, key=lambda row: indkey.index(row[1])) - pkey = tuple(row[0] for row in pkey) + pkey = tuple(row[0] for row in sorted( + pkey, key=lambda row: indkey.index(row[1]))) else: pkey = pkey[0][0] pkeys[table] = pkey # cache it @@ -2024,12 +2069,20 @@ def pkey(self, table, composite=False, flush=False): pkey = (pkey,) return pkey - def get_databases(self): + def pkeys(self, table: str) -> tuple[str, ...]: + """Get the primary key of a table as a tuple. + + Same as pkey() with 'composite' set to True. + """ + return self.pkey(table, True) # type: ignore + + def get_databases(self) -> list[str]: """Get list of databases in the system.""" - return [s[0] for s in self.db.query( + return [r[0] for r in self._valid_db.query( 'SELECT datname FROM pg_catalog.pg_database').getresult()] - def get_relations(self, kinds=None, system=False): + def get_relations(self, kinds: str | Sequence[str] | None = None, + system: bool = False) -> list[str]: """Get list of relations in connected database of specified kinds. If kinds is None or empty, all kinds of relations are returned. @@ -2038,31 +2091,32 @@ def get_relations(self, kinds=None, system=False): Set the system flag if you want to get the system relations as well. """ - where = [] + where_parts = [] if kinds: - where.append( + where_parts.append( "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) if not system: - where.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE " + ' AND '.join(where) if where else '' - q = ("SELECT" # noqa: S608 - " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' 
OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" - " ORDER BY s.nspname, r.relname") - return [r[0] for r in self.db.query(q).getresult()] - - def get_tables(self, system=False): + where_parts.append("s.nspname NOT SIMILAR" + " TO 'pg/_%|information/_schema' ESCAPE '/'") + where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' + cmd = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") + return [r[0] for r in self._valid_db.query(cmd).getresult()] + + def get_tables(self, system: bool = False) -> list[str]: """Return list of tables in connected database. Set the system flag if you want to get the system tables as well. """ return self.get_relations('r', system) - def get_attnames(self, table, with_oid=True, flush=False): + def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False + ) -> AttrDict: """Given the name of a table, dig out the set of attribute names. Returns a read-only dictionary of attribute names (the names are @@ -2083,19 +2137,18 @@ def get_attnames(self, table, with_oid=True, flush=False): try: # cache lookup names = attnames[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" if with_oid: - q = f"({q} OR a.attname OPERATOR(pg_catalog.=) 'oid')" - q = self._query_attnames.format( - _quote_if_unqualified('$1', table), q) - names = self.db.query(q, (table,)).getresult() + cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + cmd = self._query_attnames.format( + _quote_if_unqualified('$1', table), cmd) + names = self._valid_db.query(cmd, (table,)).getresult() types = self.dbtypes - names = ((name[0], types.add(*name[1:])) for name in names) - names = AttrDict(names) + names = AttrDict((name[0], types.add(*name[1:])) for name in names) attnames[table] = names # cache it return names - def get_generated(self, table, flush=False): + def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: """Given the name of a table, dig out the set of generated columns. Returns a set of column names that are generated and unalterable. 
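For orientation, a small sketch of how the cached metadata helpers annotated
above are typically used, including the new pkeys() shortcut; the table and
column names are illustrative assumptions, not part of the patch::

    import pg

    db = pg.DB(dbname='mydb')    # hypothetical connection
    db.pkey('person')            # single primary key as str, e.g. 'id'
    db.pkeys('person')           # always a tuple, e.g. ('id',)
    db.get_attnames('person')    # read-only mapping of column types
    db.get_generated('person')   # frozenset of generated column names

pkey(), get_attnames() and get_generated() cache their results per table;
passing flush=True discards the cached entry and queries the system
catalogs again.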
@@ -2111,28 +2164,28 @@ def get_generated(self, table, flush=False): try: # cache lookup names = generated[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0" - q = f"{q} AND {self._query_generated}" - q = self._query_attnames.format( - _quote_if_unqualified('$1', table), q) - names = self.db.query(q, (table,)).getresult() + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = f"{cmd} AND {self._query_generated}" + cmd = self._query_attnames.format( + _quote_if_unqualified('$1', table), cmd) + names = self._valid_db.query(cmd, (table,)).getresult() names = frozenset(name[0] for name in names) generated[table] = names # cache it return names - def use_regtypes(self, regtypes=None): + def use_regtypes(self, regtypes: bool | None = None) -> bool: """Use registered type names instead of simplified type names.""" if regtypes is None: return self.dbtypes._regtypes - else: - regtypes = bool(regtypes) - if regtypes != self.dbtypes._regtypes: - self.dbtypes._regtypes = regtypes - self._attnames.clear() - self.dbtypes.clear() - return regtypes - - def has_table_privilege(self, table, privilege='select', flush=False): + regtypes = bool(regtypes) + if regtypes != self.dbtypes._regtypes: + self.dbtypes._regtypes = regtypes + self._attnames.clear() + self.dbtypes.clear() + return regtypes + + def has_table_privilege(self, table: str, privilege: str = 'select', + flush: bool = False) -> bool: """Check whether current user has specified table privilege. If flush is set, then the internal cache for table privileges will @@ -2146,14 +2199,15 @@ def has_table_privilege(self, table, privilege='select', flush=False): try: # ask cache ret = privileges[table, privilege] except KeyError: # cache miss, ask the database - q = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( _quote_if_unqualified('$1', table)) - q = self.db.query(q, (table, privilege)) - ret = q.singlescalar() == self._make_bool(True) + query = self._valid_db.query(cmd, (table, privilege)) + ret = query.singlescalar() == self._make_bool(True) privileges[table, privilege] = ret # cache it return ret - def get(self, table, row, keyname=None): + def get(self, table: str, row: Any, + keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: """Get a row from a database table or view. This method is the basic mechanism to get a single row. 
It assumes @@ -2181,7 +2235,7 @@ def get(self, table, row, keyname=None): row['oid'] = row[qoid] if not keyname: try: # if keyname is not specified, try using the primary key - keyname = self.pkey(table, True) + keyname = self.pkeys(table) except KeyError as e: # the table has no primary key # try using the oid instead if qoid and isinstance(row, dict) and 'oid' in row: @@ -2216,10 +2270,10 @@ row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608 + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if not res: # make where clause in error message better readable where = where.replace('OPERATOR(pg_catalog.=)', '=') @@ -2232,7 +2286,8 @@ row[n] = value return row - def insert(self, table, row=None, **kw): + def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: """Insert a row into a database table. This method inserts a row into a table. The name of the table must @@ -2258,21 +2313,21 @@ params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - names, values = [], [] + name_list, value_list = [], [] for n in attnames: if n in row and n not in generated: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - if not names: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + if not name_list: raise _prg_error('No column found that can be inserted') - names, values = ', '.join(names), ', '.join(values) + names, values = ', '.join(name_list), ', '.join(value_list) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'INSERT INTO {t} ({names})' # noqa: S608 - f' VALUES ({values}) RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # this should always be true for n, value in res[0].items(): if qoid and n == 'oid': @@ -2280,7 +2335,8 @@ row[n] = value return row - def update(self, table, row=None, **kw): + def update(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: """Update an existing row in a database table. Similar to insert, but updates an existing row. The update is based @@ -2304,39 +2360,40 @@ if qoid and qoid in row and 'oid' not in row: row['oid'] = row[qoid] if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) + keynames: tuple[str, ...]
= ('oid',) + keyset = set(keynames) else: # try using the primary key try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: # the table has no primary key raise _prg_error(f'Table {table} has no primary key') from e + keyset = set(keynames) # check whether all key columns have values - if not set(keyname).issubset(row): + if not keyset.issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) + col(k), adapt(row[k], attnames[k])) for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - values = [] - keyname = set(keyname) + values_list = [] for n in attnames: - if n in row and n not in keyname and n not in generated: - values.append(f'{col(n)} = {adapt(row[n], attnames[n])}') - if not values: + if n in row and n not in keyset and n not in generated: + values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}') + if not values_list: return row - values = ', '.join(values) + values = ', '.join(values_list) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'UPDATE {t} SET {values}' # noqa: S608 - f' WHERE {where} RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'UPDATE {t} SET {values}' # noqa: S608 + f' WHERE {where} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # may be empty when row does not exist for n, value in res[0].items(): if qoid and n == 'oid': @@ -2344,7 +2401,8 @@ def update(self, table, row=None, **kw): row[n] = value return row - def upsert(self, table, row=None, **kw): + def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: """Insert a row into a database table with conflict resolution. 
This method inserts a row into a table, but instead of raising a @@ -2402,22 +2460,22 @@ def upsert(self, table, row=None, **kw): params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - names, values = [], [] + name_list, value_list = [], [] for n in attnames: if n in row and n not in generated: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - names, values = ', '.join(names), ', '.join(values) + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + names, values = ', '.join(name_list), ', '.join(value_list) try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: raise _prg_error(f'Table {table} has no primary key') from e - target = ', '.join(col(k) for k in keyname) + target = ', '.join(col(k) for k in keynames) update = [] - keyname = set(keyname) - keyname.add('oid') + keyset = set(keynames) + keyset.add('oid') for n in attnames: - if n not in keyname and n not in generated: + if n not in keyset and n not in generated: value = kw.get(n, n in row) if value: if not isinstance(value, str): @@ -2428,12 +2486,12 @@ def upsert(self, table, row=None, **kw): do = 'update set ' + ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 - f' VALUES ({values})' - f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # may be empty with "do nothing" for n, value in res[0].items(): if qoid and n == 'oid': @@ -2443,7 +2501,8 @@ def upsert(self, table, row=None, **kw): self.get(table, row) return row - def clear(self, table, row=None): + def clear(self, table: str, row: dict[str, Any] | None = None + ) -> dict[str, Any]: """Clear all the attributes to values determined by the types. Numeric types are set to 0, Booleans are set to false, and everything @@ -2467,7 +2526,8 @@ def clear(self, table, row=None): row[n] = '' return row - def delete(self, table, row=None, **kw): + def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> int: """Delete an existing row in a database table. This method deletes the row from a table. It deletes based on the @@ -2492,31 +2552,33 @@ def delete(self, table, row=None, **kw): if qoid and qoid in row and 'oid' not in row: row['oid'] = row[qoid] if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) + keynames: tuple[str, ...] 
= ('oid',) else: # try using the primary key try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: # the table has no primary key raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values - if not set(keyname).issubset(row): + if not set(keynames).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) + col(k), adapt(row[k], attnames[k])) for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'DELETE FROM {t} WHERE {where}' # noqa: S608 - self._do_debug(q, params) - res = self.db.query(q, params) + cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 + self._do_debug(cmd, params) + res = self._valid_db.query(cmd, params) return int(res) - def truncate(self, table, restart=False, cascade=False, only=False): + def truncate(self, table: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str], restart: bool = False, + cascade: bool = False, + only: bool | list[bool] | tuple[bool, ...] = False) -> Query: """Empty a table or set of tables. This method quickly removes all rows from the given table or set @@ -2528,21 +2590,21 @@ def truncate(self, table, restart=False, cascade=False, only=False): If restart is set to True, sequences owned by columns of the truncated table(s) are automatically restarted. If cascade is set to True, it also truncates all tables that have foreign-key references to any of - the named tables. If the parameter only is not set to True, all the + the named tables. If the parameter 'only' is not set to True, all the descendant tables (if any) will also be truncated. Optionally, a '*' can be specified after the table name to explicitly indicate that descendant tables are included.
""" if isinstance(table, str): - only = {table: only} + table_only = {table: only} table = [table] elif isinstance(table, (list, tuple)): if isinstance(only, (list, tuple)): - only = dict(zip(table, only)) + table_only = dict(zip(table, only)) else: - only = dict.fromkeys(table, only) + table_only = dict.fromkeys(table, only) elif isinstance(table, (set, frozenset)): - only = dict.fromkeys(table, only) + table_only = dict.fromkeys(table, only) else: raise TypeError('The table must be a string, list or set') if not (restart is None or isinstance(restart, (bool, int))): @@ -2551,7 +2613,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): raise TypeError('Invalid type for the cascade option') tables = [] for t in table: - u = only.get(t) + u = table_only.get(t) if not (u is None or isinstance(u, (bool, int))): raise TypeError('Invalid type for the only option') if t.endswith('*'): @@ -2563,17 +2625,21 @@ def truncate(self, table, restart=False, cascade=False, only=False): if u: t = f'ONLY {t}' tables.append(t) - q = ['TRUNCATE', ', '.join(tables)] + cmd_parts = ['TRUNCATE', ', '.join(tables)] if restart: - q.append('RESTART IDENTITY') + cmd_parts.append('RESTART IDENTITY') if cascade: - q.append('CASCADE') - q = ' '.join(q) - self._do_debug(q) - return self.db.query(q) - - def get_as_list(self, table, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + cmd_parts.append('CASCADE') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def get_as_list(self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: """Get a table as a list. This gets a convenient representation of the table as a list @@ -2585,16 +2651,18 @@ def get_as_list(self, table, what=None, where=None, The parameter 'what' can restrict the query to only return a subset of the table columns. It can be a string, list or a tuple. + The parameter 'where' can restrict the query to only return a subset of the table rows. It can be a string, list or a tuple - of SQL expressions that all need to be fulfilled. The parameter - 'order' specifies the ordering of the rows. It can also be a - other string, list or a tuple. If no ordering is specified, - the result will be ordered by the primary key(s) or all columns - if no primary key exists. You can set 'order' to False if you - don't care about the ordering. The parameters 'limit' and 'offset' - can be integers specifying the maximum number of rows returned - and a number of rows skipped over. + of SQL expressions that all need to be fulfilled. + + The parameter 'order' specifies the ordering of the rows. It can + also be a string, list or a tuple. If no ordering is specified, + the result will be ordered by the primary key(s) or all columns if + no primary key exists. You can set 'order' to False if you don't + care about the ordering. The parameters 'limit' and 'offset' can be + integers specifying the maximum number of rows returned and a number + of rows skipped over. If you set the 'scalar' option to True, then instead of the named tuples you will get the first items of these tuples. 
@@ -2609,35 +2677,40 @@ def get_as_list(self, table, what=None, where=None, order = what else: what = '*' - q = ['SELECT', what, 'FROM', table] + cmd_parts = ['SELECT', what, 'FROM', table] if where: if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) + cmd_parts.extend(['WHERE', where]) if order is None: try: - order = self.pkey(table, True) + order = self.pkeys(table) except (KeyError, ProgrammingError): with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) + cmd_parts.extend(['ORDER BY', order]) if limit: - q.append(f'LIMIT {limit}') + cmd_parts.append(f'LIMIT {limit}') if offset: - q.append(f'OFFSET {offset}') - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.namedresult() + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.namedresult() if res and scalar: res = [row[0] for row in res] return res - def get_as_dict(self, table, keyname=None, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + def get_as_dict(self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: """Get a table as a dictionary. This method is similar to get_as_list(), but returns the table @@ -2652,7 +2725,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, be set as a string, list or a tuple. If the Python version supports it, the dictionary will be an - OrderedDict using the order specified with the 'order' parameter + ordered dict using the order specified with the 'order' parameter or the key column(s) if not specified. You can set 'order' to False if you don't care about the ordering. In this case the returned dictionary will be an ordinary one. @@ -2661,12 +2734,14 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, raise TypeError('The table name is missing') if not keyname: try: - keyname = self.pkey(table, True) + keyname = self.pkeys(table) except (KeyError, ProgrammingError) as e: raise _prg_error(f'Table {table} has no primary key') from e if isinstance(keyname, str): - keyname = [keyname] - elif not isinstance(keyname, (list, tuple)): + keynames: list[str] | tuple[str, ...]
= (keyname,) + elif isinstance(keyname, (list, tuple)): + keynames = keyname + else: + raise KeyError('The keyname must be a string, list or tuple') if what: if isinstance(what, (list, tuple)): @@ -2675,64 +2750,68 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, order = what else: what = '*' - q = ['SELECT', what, 'FROM', table] + cmd_parts = ['SELECT', what, 'FROM', table] if where: if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) + cmd_parts.extend(['WHERE', where]) if order is None: order = keyname if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) + cmd_parts.extend(['ORDER BY', order]) if limit: - q.append(f'LIMIT {limit}') + cmd_parts.append(f'LIMIT {limit}') if offset: - q.append(f'OFFSET {offset}') - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.getresult() - cls = OrderedDict if order else dict + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.getresult() if not res: - return cls() - keyset = set(keyname) - fields = q.listfields() + return {} + keyset = set(keynames) + fields = query.listfields() if not keyset.issubset(fields): raise KeyError('Missing keyname in row') - keyind, rowind = [], [] + key_index: list[int] = [] + row_index: list[int] = [] for i, f in enumerate(fields): - (keyind if f in keyset else rowind).append(i) - keytuple = len(keyind) > 1 - getkey = itemgetter(*keyind) - keys = map(getkey, res) + (key_index if f in keyset else row_index).append(i) + key_tuple = len(key_index) > 1 + get_key = itemgetter(*key_index) + keys = map(get_key, res) if scalar: - rowind = rowind[:1] - rowtuple = False + row_index = row_index[:1] + row_is_tuple = False else: - rowtuple = len(rowind) > 1 - if scalar or rowtuple: - getrow = itemgetter(*rowind) + row_is_tuple = len(row_index) > 1 + if scalar or row_is_tuple: + get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore + *row_index) else: - rowind = rowind[0] + frst_index = row_index[0] - def getrow(row): - return row[rowind], # tuple with one item + def get_row(row: tuple) -> tuple: + return row[frst_index], # tuple with one item - rowtuple = True - rows = map(getrow, res) - if keytuple or rowtuple: - if keytuple: - keys = _namediter(_MemoryQuery(keys, keyname)) - if rowtuple: + row_is_tuple = True + rows = map(get_row, res) + if key_tuple or row_is_tuple: + if key_tuple: + keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore + if row_is_tuple: fields = [f for f in fields if f not in keyset] - rows = _namediter(_MemoryQuery(rows, fields)) + rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore # noinspection PyArgumentList - return cls(zip(keys, rows)) + return dict(zip(keys, rows)) - def notification_handler(self, event, callback, - arg_dict=None, timeout=None, stop_event=None): + def notification_handler(self, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None + ) -> NotificationHandler: """Get notification handler that will run the given callback.""" return NotificationHandler(self, event, callback, arg_dict, timeout, stop_event) diff --git a/pgdb.py b/pgdb.py index 2e48e39d..df23bbfd 100644 --- a/pgdb.py +++ b/pgdb.py @@ -64,6 +64,8 @@ connection.close() # close the connection """ +from __future__ import annotations + from collections import namedtuple from collections.abc
import Iterable from contextlib import suppress @@ -76,7 +78,7 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from typing import ClassVar, Dict, Type +from typing import Callable, ClassVar from uuid import UUID as Uuid # noqa: N811 try: @@ -91,15 +93,16 @@ if os.path.exists(os.path.join(path, libpq))] if sys.version_info >= (3, 8): # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore for path in paths: - with os.add_dll_directory(os.path.abspath(path)): + with add_dll_dir(os.path.abspath(path)): try: - from _pg import version + from _pg import version # type: ignore except ImportError: pass else: del version - e = None + e = None # type: ignore break if paths: libpq = 'compatible ' + libpq @@ -140,7 +143,7 @@ 'Date', 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', 'Binary', 'Interval', 'Uuid', - 'Hstore', 'Json', 'Literal', 'Type', + 'Hstore', 'Json', 'Literal', 'DbType', 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', 'SMALLINT', 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', 'TIMESTAMP', 'INTERVAL', @@ -150,9 +153,10 @@ 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', 'apilevel', 'connect', 'paramstyle', 'threadsafety', 'get_typecast', 'set_typecast', 'reset_typecast', - 'version', '__version__'] + 'version', '__version__', +] -Decimal = StdDecimal +Decimal: type = StdDecimal # *** Module Constants *** @@ -173,17 +177,19 @@ # *** Internal Type Handling *** -def get_args(func): +def get_args(func: Callable) -> list: return list(signature(func).parameters) # time zones used in Postgres timestamptz output -_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} -def _timezone_as_offset(tz): +def _timezone_as_offset(tz: str) -> str: if tz.startswith(('+', '-')): if len(tz) < 5: return tz + '00' @@ -191,7 +197,7 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def decimal_type(decimal_type=None): +def decimal_type(decimal_type: type | None = None): """Get or set global type to be used for decimal values. Note that connections cache cast functions. To be sure a global change @@ -204,25 +210,25 @@ def decimal_type(decimal_type=None): return Decimal -def cast_bool(value): +def cast_bool(value: str) -> bool | None: """Cast boolean value in database format to bool.""" if value: return value[0] in ('t', 'T') -def cast_money(value): +def cast_money(value: str) -> Decimal | None: # pyright: ignore """Cast money value in database format to Decimal.""" if value: value = value.replace('(', '-') return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) -def cast_int2vector(value): +def cast_int2vector(value: str) -> list[int]: """Cast an int2vector value.""" return [int(v) for v in value.split()] -def cast_date(value, connection): +def cast_date(value: str, connection) -> date: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. 
The @@ -232,17 +238,17 @@ def cast_date(value, connection): return date.min if value == 'infinity': return date.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return date.min - value = value[0] + value = values[0] if len(value) > 10: return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() + format = connection.date_format() + return datetime.strptime(value, format).date() -def cast_time(value): +def cast_time(value: str) -> time: """Cast a time value.""" fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' return datetime.strptime(value, fmt).time() @@ -251,74 +257,74 @@ def cast_time(value): _re_timezone = regex('(.*)([+-].*)') -def cast_timetz(value): +def cast_timetz(value: str) -> time: """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() + m = _re_timezone.match(value) + if m: + value, tz = m.groups() else: tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() + format += '%z' + return datetime.strptime(value, format).timetz() -def cast_timestamp(value, connection): +def cast_timestamp(value: str, connection) -> datetime: """Cast a timestamp value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] else: - if len(value[0]) > 10: + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value, connection): +def cast_timestamptz(value: str, connection) -> datetime: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + 
values[1], tz = m.groups() else: tz = '+0000' else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) _re_interval_sql_standard = regex( @@ -349,37 +355,37 @@ def cast_timestamptz(value, connection): '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') -def cast_interval(value): +def cast_interval(value: str) -> timedelta: """Cast an interval value.""" # The output format depends on the server setting IntervalStyle, but it's # not necessary to consult this setting to parse it. It's faster to just # check all possible formats, and there is no ambiguity here. m = _re_interval_iso_8601.match(value) if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = -secs usecs = -usecs else: m = _re_interval_postgres_verbose.match(value) if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = - secs usecs = -usecs else: m = _re_interval_postgres.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if hours_ago: hours = -hours mins = -mins @@ -388,11 +394,11 @@ def cast_interval(value): else: m = _re_interval_sql_standard.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if years_ago: years = -years mons = -mons @@ -419,7 +425,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults: ClassVar[Dict[str, Type]] = { + defaults: ClassVar[dict[str, Callable]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -759,10 +765,6 @@ def _op_error(msg): # *** Row Tuples *** - -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - - # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. 
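The renaming above leaves the parsing strategy unchanged: cast_interval() tries the IntervalStyle regexes in turn (iso_8601, postgres_verbose, postgres, sql_standard) and builds a datetime.timedelta from whichever matches. A small sanity sketch with illustrative values, assuming the module is importable as pgdb:

    from datetime import timedelta
    from pgdb import cast_interval

    # default 'postgres' IntervalStyle output
    assert cast_interval('1 day 02:00:00') == timedelta(days=1, hours=2)
    # the same interval in 'iso_8601' style output
    assert cast_interval('P1DT2H') == timedelta(days=1, hours=2)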
diff --git a/pgmodule.c b/pgmodule.c index 628de9ec..64e769f6 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -21,7 +21,7 @@ static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, - *MultipleResultsError; + *MultipleResultsError, *Connection, *Query; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -1305,6 +1305,12 @@ PyInit__pg(void) InvalidResultError, NULL); PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); + /* Types */ + Connection = (PyObject *)&connType; + PyDict_SetItemString(dict, "Connection", Connection); + Query = (PyObject *)&queryType; + PyDict_SetItemString(dict, "Query", Query); + /* Make the version available */ s = PyUnicode_FromString(PyPgVersion); PyDict_SetItemString(dict, "version", s); diff --git a/pyproject.toml b/pyproject.toml index 131308b8..1016b433 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,21 @@ exclude = [ [tool.ruff.per-file-ignores] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] +[tool.mypy] +python_version = "3.11" +check_untyped_defs = true +no_implicit_optional = true +strict_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "tests.*" +] +disallow_untyped_defs = false + [tool.setuptools] py-modules = ["pg", "pgdb"] license-files = ["LICENSE.txt"] diff --git a/tests/config.py b/tests/config.py index f6280548..0b15f62e 100644 --- a/tests/config.py +++ b/tests/config.py @@ -26,9 +26,9 @@ dbport = int(dbport) try: - from .LOCAL_PyGreSQL import * # noqa: F403 + from .LOCAL_PyGreSQL import * # type: ignore # noqa except (ImportError, ValueError): - try: # noqa: SIM105 - from LOCAL_PyGreSQL import * # noqa: F403 + try: # noqa + from LOCAL_PyGreSQL import * # type: ignore # noqa except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index d5f2938f..f72d99f7 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -7,13 +7,14 @@ Some modernization of the code has been done by the PyGreSQL team. """ -__version__ = '1.15.0' +from __future__ import annotations import time import unittest from contextlib import suppress -from typing import Any, Mapping, Tuple +from typing import Any, Mapping +__version__ = '1.15.0' class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -41,7 +42,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # The self.driver module. This should be the module where the 'connect' # method is to be found driver: Any = None - connect_args: Tuple = () # List of arguments to pass to connect + connect_args: tuple = () # List of arguments to pass to connect connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 7d4409df..242fdbb5 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -9,6 +9,8 @@ These tests need a database to test against. 
""" +from __future__ import annotations + import os import threading import time @@ -17,7 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from decimal import Decimal -from typing import Sequence, Tuple +from typing import Sequence import pg # the module under test @@ -532,9 +534,9 @@ def test_namedresult_with_good_fieldnames(self): self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias')) def test_namedresult_with_bad_fieldnames(self): - r = namedtuple('Bad', ['?'] * 6, rename=True) + t = namedtuple('Bad', ['?'] * 6, rename=True) # type: ignore # noinspection PyUnresolvedReferences - fields = r._fields + fields = t._fields q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",' ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') result = [tuple(range(3, 10))] @@ -820,45 +822,44 @@ def tearDown(self): def test_getresul_ascii(self): result = 'Hello, world!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_dictresul_ascii(self): result = 'Hello, world!' - q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_getresult_utf8(self): result = 'Hello, wörld & мир!' - q = f"select '{result}'" + cmd = f"select '{result}'" # pass the query as unicode try: - v = self.c.query(q).getresult()[0][0] + v = self.c.query(cmd).getresult()[0][0] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode() - # pass the query as bytes - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_dictresult_utf8(self): result = 'Hello, wörld & мир!' - q = f"select '{result}' as greeting" + cmd = f"select '{result}' as greeting" try: - v = self.c.query(q).dictresult()[0]['greeting'] + v = self.c.query(cmd).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode() - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -868,12 +869,12 @@ def test_getresult_latin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") result = 'Hello, wörld!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -883,12 +884,12 @@ def test_dictresult_latin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") result = 'Hello, wörld!' 
- q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -898,12 +899,12 @@ def test_getresult_cyrillic(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") result = 'Hello, мир!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -913,12 +914,12 @@ def test_dictresult_cyrillic(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") result = 'Hello, мир!' - q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -928,12 +929,12 @@ def test_getresult_latin9(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -943,12 +944,12 @@ def test_dictresult_latin9(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = f"select '{result}' as menu" - v = self.c.query(q).dictresult()[0]['menu'] + cmd = f"select '{result}' as menu" + v = self.c.query(cmd).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).dictresult()[0]['menu'] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1698,6 +1699,7 @@ class TestInserttable(unittest.TestCase): """Test inserttable method.""" cls_set_up = False + has_encoding = False @classmethod def setUpClass(cls): @@ -1738,7 +1740,7 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - data: Sequence[Tuple] = [ + data: Sequence[tuple] = [ (-1, -1, -1, True, '1492-10-12', '08:30:00', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), (0, 0, 0, False, '1607-04-14', '09:00:00', @@ -1868,7 +1870,7 @@ def test_inserttable_from_list_of_sets(self): def test_inserttable_multiple_rows(self): num_rows = 
100 - data = self.data[2:3] * num_rows + data = list(self.data[2:3]) * num_rows self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) @@ -1892,13 +1894,13 @@ def test_inserttable_no_column(self): self.assertEqual(self.get_back(), []) def test_inserttable_only_one_column(self): - data = [(42,)] * 50 + data: list[tuple] = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 self.assertEqual(self.get_back(), data) def test_inserttable_only_two_columns(self): - data = [(bool(i % 2), i * .5) for i in range(20)] + data: list[tuple] = [(bool(i % 2), i * .5) for i in range(20)] self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker data = [(None,) * 3 + (bool(i % 2),) + (None,) * 3 + (i * .5,) @@ -2021,7 +2023,7 @@ def test_inserttable_unicode_latin1(self): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' - row_unicode = ( + row_unicode: tuple = ( 0, 0, 0, False, '1970-01-01', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3884436f..31aec400 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -9,19 +9,21 @@ These tests need a database to test against. """ +from __future__ import annotations + import gc import json import os import sys import tempfile import unittest -from collections import OrderedDict from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from io import StringIO from operator import itemgetter from time import strftime +from typing import Any, ClassVar from uuid import UUID import pg # the module under test @@ -51,20 +53,19 @@ class TestAttrDict(unittest.TestCase): """Test the simple ordered dictionary for attribute names.""" cls = pg.AttrDict - base = OrderedDict def test_init(self): a = self.cls() - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base()) + self.assertIsInstance(a, dict) + self.assertEqual(a, {}) items = [('id', 'int'), ('name', 'text')] a = self.cls(items) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) iteritems = iter(items) a = self.cls(iteritems) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) def test_iter(self): a = self.cls() @@ -127,7 +128,7 @@ def test_write_methods(self): self.assertEqual(a['id'], 1) for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': method = getattr(a, method) - self.assertRaises(TypeError, method, a) + self.assertRaises(TypeError, method, a) # type: ignore class TestDBClassInit(unittest.TestCase): @@ -193,7 +194,7 @@ def test_all_db_attributes(self): 'locreate', 'loimport', 'notification_handler', 'options', - 'parameter', 'pkey', 'poll', 'port', + 'parameter', 'pkey', 'pkeys', 'poll', 'port', 'prepare', 'protocol_version', 'putline', 'query', 'query_formatted', 'query_prepared', 'release', 'reopen', 'reset', 'rollback', @@ -416,11 +417,12 @@ class TestDBClass(unittest.TestCase): cls_set_up = False regtypes = None + supports_oids = False @classmethod def setUpClass(cls): db = DB() - cls.oids = db.server_version < 120000 + 
cls.supports_oids = db.server_version < 120000 db.query("drop table if exists test cascade") db.query("create table test (" "i2 smallint, i4 integer, i8 bigint," @@ -469,21 +471,21 @@ def create_table(self, table, definition, if not as_query and not definition.startswith('('): definition = f'({definition})' with_oids = 'with oids' if oids else ( - 'without oids' if self.oids else '') - q = ['create', temporary, table] + 'without oids' if self.supports_oids else '') + cmd_parts = ['create', temporary, table] if as_query: - q.extend([with_oids, definition]) + cmd_parts.extend([with_oids, definition]) else: - q.extend([definition, with_oids]) - q = ' '.join(q) - query(q) + cmd_parts.extend([definition, with_oids]) + cmd = ' '.join(cmd_parts) + query(cmd) if values: for params in values: if not isinstance(params, (list, tuple)): params = [params] values = ', '.join(f'${n + 1}' for n in range(len(params))) - q = f"insert into {table} values ({values})" - query(q, params) + cmd = f"insert into {table} values ({values})" + query(cmd, params) def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') @@ -494,7 +496,7 @@ def test_module_name(self): def test_escape_literal(self): f = self.db.escape_literal - r = f(b"plain") + r: Any = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") r = f("plain") @@ -846,7 +848,7 @@ def test_create_table(self): self.assertEqual(r, "Hello, World!") def test_create_table_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") table = 'test hello world' values = [(2, "World!"), (1, "Hello")] @@ -893,7 +895,7 @@ def test_query(self): self.assertEqual(r, '5') def test_query_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") query = self.db.query table = 'test_table' @@ -1175,6 +1177,16 @@ def test_pkey(self): # we get the changed primary key when the cache is flushed self.assertEqual(pkey(f'{t}1', flush=True), 'x') + def test_pkeys(self): + pkeys = self.db.pkeys + t = 'pkeys_test_' + self.create_table(f'{t}0', 'a int') + self.create_table(f'{t}1', 'a int primary key, b int') + self.create_table(f'{t}2', 'a int, b int, c int, primary key (a, c)') + self.assertRaises(KeyError, pkeys, f'{t}0') + self.assertEqual(pkeys(f'{t}1'), ('a',)) + self.assertEqual(pkeys(f'{t}2'), ('a', 'c')) + def test_get_databases(self): databases = self.db.get_databases() self.assertIn('template0', databases) @@ -1194,11 +1206,11 @@ def test_get_tables(self): before_tables = get_tables() self.assertIsInstance(before_tables, list) for t in before_tables: - t = t.split('.', 1) - self.assertGreaterEqual(len(t), 2) - if len(t) > 2: - self.assertTrue(t[1].startswith('"')) - t = t[0] + s = t.split('.', 1) + self.assertGreaterEqual(len(s), 2) + if len(s) > 2: + self.assertTrue(s[1].startswith('"')) + t = s[0] self.assertNotEqual(t, 'information_schema') self.assertFalse(t.startswith('pg_')) for t in tables: @@ -1392,41 +1404,37 @@ def test_get_attnames_is_cached(self): def test_get_attnames_is_ordered(self): get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + 
self.assertEqual(r, { + 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', + 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', + 'm': 'money', 'v4': 'character varying', + 'c4': 'character', 't': 'text'}) else: - self.assertEqual(r, OrderedDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') + self.assertEqual(r, { + 'i2': 'int', 'i4': 'int', 'i8': 'int', + 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', + 'v4': 'text', 'c4': 'text', 't': 'text'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) - else: - self.assertEqual(r, OrderedDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'n alpha v gamma tau beta') + self.assertEqual(r, { + 'n': 'integer', 'alpha': 'smallint', + 'v': 'character varying', 'gamma': 'character', + 'tau': 'text', 'beta': 'boolean'}) else: - self.skipTest('OrderedDict is not supported') + self.assertEqual(r, { + 'n': 'int', 'alpha': 'int', 'v': 'text', + 'gamma': 'text', 'tau': 'text', 'beta': 'bool'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'n alpha v gamma tau beta') def test_get_attnames_is_attr_dict(self): AttrDict = pg.AttrDict # noqa: N806 @@ -1541,7 +1549,7 @@ def test_get(self): self.create_table(table, 'n integer, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) - r = get(table, 2, 'n') + r: Any = get(table, 2, 'n') self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) r = get(table, 1, 'n') @@ -1554,7 +1562,7 @@ def test_get(self): self.assertRaises(pg.DatabaseError, get, table, 4, 'n') self.assertRaises(pg.DatabaseError, get, table, 'y') self.assertRaises(pg.DatabaseError, get, table, 2, 't') - s = dict(n=3) + s: dict = dict(n=3) self.assertRaises(pg.ProgrammingError, get, table, s) r = get(table, s, 'n') self.assertIs(r, s) @@ -1588,7 +1596,7 @@ def test_get(self): self.assertRaises(KeyError, get, table, s) def test_get_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") get = self.db.get query = self.db.query @@ -1753,7 +1761,7 @@ def test_insert(self): ' d numeric, f4 real, f8 double precision, m money,' ' v4 varchar(4), c4 char(4), t text,' ' b boolean, ts timestamp') - tests = [ + tests: list[dict | tuple[dict, dict]] = [ dict(i2=None, i4=None, i8=None), (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), @@ -1798,8 +1806,8 @@ def test_insert(self): dict(ts='current_timestamp')] for test in tests: if isinstance(test, dict): - data = test - change = {} + data: dict = test + change: dict = {} else: data, change = test expect = data.copy() @@ -1835,7 +1843,7 @@ def test_insert(self): query(f'truncate table 
"{table}"') def test_insert_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") insert = self.db.insert query = self.db.query @@ -1910,7 +1918,7 @@ def test_insert_with_quoted_names(self): table = 'test table for insert()' self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') - r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} + r: Any = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} r = insert(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 11) @@ -1928,7 +1936,7 @@ def test_insert_into_view(self): query = self.db.query query("truncate table test") q = 'select * from test_view order by i4 limit 3' - r = query(q).getresult() + r: Any = query(q).getresult() self.assertEqual(r, []) r = dict(i4=1234, v4='abcd') insert('test', r) @@ -1993,7 +2001,7 @@ def test_update(self): self.assertEqual(r, 'u') def test_update_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") update = self.db.update get = self.db.get @@ -2133,7 +2141,7 @@ def test_update_with_quoted_names(self): self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(13, 3003, 'Why!')]) - r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} + r: Any = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} r = update(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 13) @@ -2166,7 +2174,7 @@ def test_update_with_generated_columns(self): self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 - r = query(f'insert into {table} (i, d) values ({i}, {d})') + r: Any = query(f'insert into {table} (i, d) values ({i}, {d})') self.assertEqual(r, '1') r = get(table, d) self.assertIsInstance(r, dict) @@ -2185,8 +2193,8 @@ def test_upsert(self): 'test', i2=2, i4=4, i8=8) table = 'upsert_test_table' self.create_table(table, 'n integer primary key, t text') - s = dict(n=1, t='x') - r = upsert(table, s) + s: dict = dict(n=1, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['t'], 'x') @@ -2252,7 +2260,7 @@ def test_upsert(self): self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) def test_upsert_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") upsert = self.db.upsert get = self.db.get @@ -2260,7 +2268,7 @@ def test_upsert_with_oids(self): self.create_table('test_table', 'n int', oids=True, values=[1]) self.assertRaises(pg.ProgrammingError, upsert, 'test_table', dict(n=2)) - r = get('test_table', 1, 'n') + r: Any = get('test_table', 1, 'n') self.assertIsInstance(r, dict) self.assertEqual(r['n'], 1) qoid = 'oid(test_table)' @@ -2338,8 +2346,8 @@ def test_upsert_with_composite_key(self): table = 'upsert_test_table_2' self.create_table( table, 'n integer, m integer, t text, primary key (n, m)') - s = dict(n=1, m=2, t='x') - r = upsert(table, s) + s: dict = dict(n=1, m=2, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 2) @@ -2400,8 +2408,8 @@ def test_upsert_with_quoted_names(self): table = 'test table for upsert()' self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text') - s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} - r = upsert(table, s) + s: dict = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) @@ -2437,7 +2445,7 @@ def test_upsert_with_generated_columns(self): self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 - r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + r: Any = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) self.assertIsInstance(r, dict) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) r['i'] += 1 @@ -2452,7 +2460,7 @@ def test_upsert_with_generated_columns(self): def test_clear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' - r = clear('test') + r: Any = clear('test') result = dict( i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) @@ -2491,8 +2499,8 @@ def test_delete(self): self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) - r = self.db.get(table, 1) - s = delete(table, r) + r: Any = self.db.get(table, 1) + s: Any = delete(table, r) self.assertEqual(s, 1) r = self.db.get(table, 3) s = delete(table, r) @@ -2516,15 +2524,15 @@ def test_delete(self): self.assertEqual(s, 0) def test_delete_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") delete = self.db.delete get = self.db.get query = self.db.query self.create_table('test_table', 'n int', oids=True, values=range(1, 7)) - r = dict(n=3) + r: Any = dict(n=3) self.assertRaises(pg.ProgrammingError, delete, 'test_table', r) - s = get('test_table', 1, 'n') + s: Any = get('test_table', 1, 'n') qoid = 'oid(test_table)' self.assertIn(qoid, s) r = delete('test_table', s) @@ -2618,7 +2626,7 @@ def test_delete_with_composite_key(self): values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) - r = query(f'select t from "{table}" where n=2').getresult() + r: Any = query(f'select t from "{table}" where n=2').getresult() self.assertEqual(r, []) self.assertEqual(self.db.delete(table, dict(n=2)), 0) r = query(f'select t from "{table}" where n=3').getresult()[0][0] @@ -2650,7 +2658,7 @@ def test_delete_with_quoted_names(self): table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text', values=[(19, 5005, 'Yes!')]) - r = {'Prime!': 17} + r: Any = {'Prime!': 17} r = delete(table, r) self.assertEqual(r, 0) r = query(f'select count(*) from "{table}"').getresult() @@ -2676,7 +2684,7 @@ def test_delete_referenced(self): delete, 'test_parent', None, n=2) self.assertRaises(pg.IntegrityError, delete, 'test_parent *', None, n=2) - r = delete('test_child', None, n=2) + r: Any = delete('test_child', None, n=2) self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (3, 2)) r = delete('test_parent', None, n=2) @@ -2706,7 +2714,7 @@ def test_temp_crud(self): self.db.insert(table, dict(n=1, t='one')) self.db.insert(table, dict(n=2, t='too')) self.db.insert(table, dict(n=3, t='three')) - r = self.db.get(table, 2) + r: Any = self.db.get(table, 2) self.assertEqual(r['t'], 'too') self.db.update(table, dict(n=2, t='two')) r = self.db.get(table, 2) @@ -2724,7 +2732,7 @@ def test_truncate(self): self.create_table('test_table', 'n smallint', temporary=False, values=[1] * 3) q = "select count(*) from test_table" - r = query(q).getresult()[0][0] + r: Any = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate('test_table') r = query(q).getresult()[0][0] @@ -2757,7 +2765,7 @@ def test_truncate_restart(self): for _n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) truncate('test_table') r = query(q).getresult()[0] @@ -2785,7 +2793,7 @@ def test_truncate_cascade(self): values=range(3)) q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)") - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 3)) self.assertRaises(pg.NotSupportedError, truncate, 'test_parent') truncate(['test_parent', 'test_child']) @@ -2899,7 +2907,7 @@ def test_get_as_list(self): self.assertRaises(TypeError, get_as_list, None) query = self.db.query table = 'test_aslist' - r = query('select 1 as colname').namedresult()[0] + r: Any = query('select 1 as colname').namedresult()[0] self.assertIsInstance(r, tuple) named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), @@ -2918,7 +2926,7 @@ def test_get_as_list(self): self.assertEqual(t._asdict(), dict(id=n[0], name=n[1])) r = get_as_list(table, what='name') self.assertIsInstance(r, list) - expected = sorted((row[1],) for row in names) + expected: Any = sorted((row[1],) for row in names) self.assertEqual(r, expected) r = get_as_list(table, what='name, id') self.assertIsInstance(r, list) @@ -3029,8 +3037,8 @@ def test_get_as_dict(self): self.assertRaises(KeyError, get_as_dict, table, keyname='rgb', what='name') r = get_as_dict(table) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[1:]) for row in colors) + self.assertIsInstance(r, dict) + expected: Any = {row[0]: row[1:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -3045,9 +3053,9 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[1], (row[0], row[2])) - for row in sorted(colors, key=itemgetter(1))) + self.assertIsInstance(r, dict) + expected = {row[1]: (row[0], row[2]) + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) @@ -3063,8 
+3071,8 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(id=t[0], name=t[1])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb']) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2:]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) @@ -3084,8 +3092,8 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(name=t[0])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) @@ -3097,9 +3105,9 @@ def test_get_as_dict(self): self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict( - (row[1], row[2]) for row in sorted(colors, key=itemgetter(1))) + self.assertIsInstance(r, dict) + expected = {row[1]: row[2] + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) @@ -3111,8 +3119,8 @@ def test_get_as_dict(self): self.assertEqual(r.keys(), expected.keys()) r = get_as_dict( table, what='id, name', where="rgb like '#b%'", scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[2]) for row in colors[1:3]) + self.assertIsInstance(r, dict) + expected = {row[0]: row[2] for row in colors[1:3]} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -3140,31 +3148,31 @@ def test_get_as_dict(self): self.assertEqual(len(r), 1) self.assertEqual(r[4][1], 'Desert') r = get_as_dict(table, order='id desc') - expected = OrderedDict((row[0], row[1:]) for row in reversed(colors)) + expected = {row[0]: row[1:] for row in reversed(colors)} self.assertEqual(r, expected) r = get_as_dict(table, where='id > 5') - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) self.assertEqual(len(r), 0) # test with unordered query expected = {row[0]: row[1:] for row in colors} r = get_as_dict(table, order=False) self.assertIsInstance(r, dict) self.assertEqual(r, expected) - self.assertNotIsInstance(self, OrderedDict) + self.assertNotIsInstance(self, dict) # test with arbitrary from clause from_table = f'(select id, lower(name) as n2 from "{table}") as t2' # primary key must be passed explicitly in this case self.assertRaises(pg.ProgrammingError, get_as_dict, from_table) r = get_as_dict(from_table, 'id') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[0]: (row[2].lower(),) for row in colors} self.assertEqual(r, expected) # test without a primary key query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) self.assertRaises(pg.ProgrammingError, get_as_dict, table) r = get_as_dict(table, keyname='id') - expected = OrderedDict((row[0], row[1:]) for row in colors) + expected = {row[0]: row[1:] for row in colors} self.assertIsInstance(r, dict) self.assertEqual(r, expected) r = (1, '#007fff', 'Azure') @@ -3783,14 +3791,17 @@ def test_insert_update_get_record(self): 
name='text', age='int', married='bool', weight='float', salary='money')) decimal = pg.get_decimal() + bool_class: type + t: bool | str + f: bool | str if pg.get_bool(): bool_class = bool t, f = True, False else: bool_class = str t, f = 't', 'f' - person = ('John Doe', 61, t, 99.5, decimal('93456.75')) - r = self.db.insert('test_person', None, person=person) + person: tuple = ('John Doe', 61, t, 99.5, decimal('93456.75')) + r: Any = self.db.insert('test_person', None, person=person) self.assertEqual(r['id'], 1) p = r['person'] self.assertIsInstance(p, tuple) @@ -4301,9 +4312,11 @@ def test_inserttable_from_query(self): class TestDBClassNonStdOpts(TestDBClass): """Test the methods of the DB class with non-standard global options.""" + saved_options: ClassVar[dict[str, Any]] = {} + @classmethod def setUpClass(cls): - cls.saved_options = {} + cls.saved_options.clear() cls.set_option('decimal', float) not_bool = not pg.get_bool() cls.set_option('bool', not_bool) @@ -4375,8 +4388,8 @@ def test_adapt_query_typed_list(self): self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) self.assertRaises( TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) - values = (3, 7.5, 'hello', True) - types = ('int4', 'float4', 'text', 'bool') + values: list | tuple = (3, 7.5, 'hello', True) + types: list | tuple = ('int4', 'float4', 'text', 'bool') sql, params = format_query("select %s,%s,%s,%s", values, types) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4434,7 +4447,7 @@ def test_adapt_query_typed_list_with_types_as_classes(self): def test_adapt_query_typed_list_with_json(self): format_query = self.adapter.format_query - value = {'test': [1, "it's fine", 3]} + value: Any = {'test': [1, "it's fine", 3]} sql, params = format_query("select %s", (value,), 'json') self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) @@ -4449,7 +4462,7 @@ def test_adapt_query_typed_list_with_json(self): def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query - value = {'one': "it's fine", 'two': 2} + value: Any = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) @@ -4464,7 +4477,7 @@ def test_adapt_query_typed_with_hstore(self): def test_adapt_query_typed_with_uuid(self): format_query = self.adapter.format_query - value = '12345678-1234-5678-1234-567812345678' + value: Any = '12345678-1234-5678-1234-567812345678' sql, params = format_query("select %s", (value,), 'uuid') self.assertEqual(sql, "select $1") self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) @@ -4482,8 +4495,8 @@ def test_adapt_query_typed_dict(self): self.assertRaises( TypeError, format_query, '%s,%s', dict(i1=1, i2=2), dict(i1='int2')) - values = dict(i=3, f=7.5, t='hello', b=True) - types = dict(i='int4', f='float4', t='text', b='bool') + values: dict = dict(i=3, f=7.5, t='hello', b=True) + types: dict = dict(i='int4', f='float4', t='text', b='bool') sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, types) self.assertEqual(sql, 'select $3,$2,$4,$1') @@ -4523,7 +4536,7 @@ def test_adapt_query_typed_dict(self): def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values) self.assertEqual(sql, 'select 
$1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4562,7 +4575,7 @@ def test_adapt_query_untyped_with_hstore(self): def test_adapt_query_untyped_dict(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values) self.assertEqual(sql, 'select $3,$2,$4,$1') @@ -4589,7 +4602,7 @@ def test_adapt_query_untyped_dict(self): def test_adapt_query_inline_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values, inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") self.assertEqual(params, []) @@ -4633,7 +4646,7 @@ def test_adapt_query_inline_list_with_hstore(self): def test_adapt_query_inline_dict(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") @@ -4683,6 +4696,7 @@ class TestSchemas(unittest.TestCase): """Test correct handling of schemas (namespaces).""" cls_set_up = False + with_oids = "" @classmethod def setUpClass(cls): @@ -4823,11 +4837,11 @@ def test_query_information_schema(self): class TestDebug(unittest.TestCase): """Test the debug attribute of the DB class.""" - + def setUp(self): self.db = DB() self.query = self.db.query - self.debug = self.db.debug + self.debug = self.db.debug # type: ignore self.output = StringIO() self.stdout, sys.stdout = sys.stdout, self.output @@ -4877,7 +4891,7 @@ def test_debug_is_file_like(self): self.assertEqual(self.get_output(), "") def test_debug_is_callable(self): - output = [] + output: list[str] = [] self.db.debug = output.append self.db.query("select 1") self.db.query("select 2") @@ -4885,7 +4899,7 @@ def test_debug_is_callable(self): self.assertEqual(self.get_output(), "") def test_debug_multiple_args(self): - output = [] + output: list[str] = [] self.db.debug = output.append args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]] self.db._do_debug(*args) @@ -4897,8 +4911,8 @@ class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" def get_leaks(self, fut): - ids = set() - objs = [] + ids: set = set() + objs: list = [] add_ids = ids.update gc.collect() objs[:] = gc.get_objects() diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 37606b13..33c2f6f9 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -9,11 +9,13 @@ These tests do not need a database to test against. 
""" +from __future__ import annotations + import json import re import unittest from datetime import timedelta -from typing import Any, Sequence, Tuple, Type +from typing import Any, Sequence import pg # the module under test @@ -21,56 +23,64 @@ class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" - def testhas_pg_error(self): + def test_has_pg_error(self): self.assertTrue(issubclass(pg.Error, Exception)) - def testhas_pg_warning(self): + def test_has_pg_warning(self): self.assertTrue(issubclass(pg.Warning, Exception)) - def testhas_pg_interface_error(self): + def test_has_pg_interface_error(self): self.assertTrue(issubclass(pg.InterfaceError, pg.Error)) - def testhas_pg_database_error(self): + def test_has_pg_database_error(self): self.assertTrue(issubclass(pg.DatabaseError, pg.Error)) - def testhas_pg_internal_error(self): + def test_has_pg_internal_error(self): self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError)) - def testhas_pg_operational_error(self): + def test_has_pg_operational_error(self): self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError)) - def testhas_pg_programming_error(self): + def test_has_pg_programming_error(self): self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError)) - def testhas_pg_integrity_error(self): + def test_has_pg_integrity_error(self): self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError)) - def testhas_pg_data_error(self): + def test_has_pg_data_error(self): self.assertTrue(issubclass(pg.DataError, pg.DatabaseError)) - def testhas_pg_not_supported_error(self): + def test_has_pg_not_supported_error(self): self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError)) - def testhas_pg_invalid_result_error(self): + def test_has_pg_invalid_result_error(self): self.assertTrue(issubclass(pg.InvalidResultError, pg.DataError)) - def testhas_pg_no_result_error(self): + def test_has_pg_no_result_error(self): self.assertTrue(issubclass(pg.NoResultError, pg.InvalidResultError)) - def testhas_pg_multiple_results_error(self): + def test_has_pg_multiple_results_error(self): self.assertTrue( issubclass(pg.MultipleResultsError, pg.InvalidResultError)) - def testhas_connect(self): + def test_has_connection_type(self): + self.assertIsInstance(pg.Connection, type) + self.assertEqual(pg.Connection.__name__, 'Connection') + + def test_has_query_type(self): + self.assertIsInstance(pg.Query, type) + self.assertEqual(pg.Query.__name__, 'Query') + + def test_has_connect(self): self.assertTrue(callable(pg.connect)) - def testhas_escape_string(self): + def test_has_escape_string(self): self.assertTrue(callable(pg.escape_string)) - def testhas_escape_bytea(self): + def test_has_escape_bytea(self): self.assertTrue(callable(pg.escape_bytea)) - def testhas_unescape_bytea(self): + def test_has_unescape_bytea(self): self.assertTrue(callable(pg.unescape_bytea)) def test_def_host(self): @@ -120,7 +130,7 @@ def test_pqlib_version(self): class TestParseArray(unittest.TestCase): """Test the array parser.""" - test_strings: Sequence[Tuple[str, Type, Any]] = [ + test_strings: Sequence[tuple[str, type | None, Any]] = [ ('', str, ValueError), ('{}', None, []), ('{}', str, []), @@ -354,7 +364,7 @@ def replace_comma(value): class TestParseRecord(unittest.TestCase): """Test the record parser.""" - test_strings: Sequence[Tuple[str, Type, Any]] = [ + test_strings: Sequence[tuple[str, type | tuple[type, ...] 
| None, Any]] = [ ('', None, ValueError), ('', str, ValueError), ('(', None, ValueError), @@ -635,7 +645,7 @@ def replace_comma(value): class TestParseHStore(unittest.TestCase): """Test the hstore parser.""" - test_strings: Sequence[Tuple[str, Any]] = [ + test_strings: Sequence[tuple[str, Any]] = [ ('', {}), ('=>', ValueError), ('""=>', ValueError), @@ -684,7 +694,7 @@ def test_parser(self): class TestCastInterval(unittest.TestCase): """Test the interval typecast function.""" - intervals: Sequence[Tuple[Tuple[int, ...], Tuple[str, ...]]] = [ + intervals: Sequence[tuple[tuple[int, ...], tuple[str, ...]]] = [ ((0, 0, 0, 1, 0, 0, 0), ('1:00:00', '01:00:00', '@ 1 hour', 'PT1H')), ((0, 0, 0, -1, 0, 0, 0), diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 7e5ad4a2..4fb8773c 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -13,6 +13,7 @@ import tempfile import unittest from contextlib import suppress +from typing import Any import pg # the module under test @@ -105,6 +106,7 @@ def test_get_lo(self): self.assertEqual(r, data) def test_lo_import(self): + f : Any if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_import.txt' @@ -412,6 +414,7 @@ def test_export(self): self.assertRaises(TypeError, export) self.assertRaises(TypeError, export, 0) self.assertRaises(TypeError, export, 'invalid', 0) + f: Any if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_export.txt' diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index bcacd476..09211718 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -9,10 +9,12 @@ These tests need a database to test against. """ +from __future__ import annotations # + import unittest from collections.abc import Iterable from contextlib import suppress -from typing import Sequence, Tuple +from typing import Sequence import pgdb # the module under test @@ -150,7 +152,7 @@ def tearDown(self): with suppress(Exception): self.con.close() - data: Sequence[Tuple[int, str]] = [ + data: Sequence[tuple[int, str]] = [ (1935, 'Luciano Pavarotti'), (1941, 'Plácido Domingo'), (1946, 'José Carreras')] diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 3f76f39b..c28fbefc 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,6 +1,7 @@ #!/usr/bin/python import unittest +from typing import Any from pg import DB from pgdb import connect @@ -29,7 +30,7 @@ def tearDown(self): def test_all_steps(self): db = self.db - r = db.get_tables() + r: Any = db.get_tables() self.assertIsInstance(r, list) self.assertIn('public.fruits', r) r = db.get_attnames('fruits') diff --git a/tox.ini b/tox.ini index 37b3a39d..7e52747d 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,12 @@ deps = ruff>=0.0.287 commands = ruff setup.py pg.py pgdb.py tests +[testenv:mypy] +basepython = python3.11 +deps = mypy>=1.5.1 +commands = + mypy setup.py pg.py pgdb.py tests + [testenv:cformat] basepython = python3.11 allowlist_externals = From 08c43d8fef8a2bc23284d09a0b5d637e961d042a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 00:55:43 +0200 Subject: [PATCH 150/194] Add type hints for the pgdb module --- docs/contents/changelog.rst | 3 +- docs/contents/pg/db_wrapper.rst | 2 + pg.py | 46 +-- pgdb.py | 613 +++++++++++++++++--------------- pgsource.c | 3 +- tests/dbapi20.py | 6 +- tests/test_classic_dbwrapper.py | 4 +- tests/test_dbapi20.py | 116 +++--- tests/test_dbapi20_copy.py | 35 +- 
tests/test_tutorial.py | 2 +- 10 files changed, 458 insertions(+), 372 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index d240daa2..6afc68dd 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,8 +5,9 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Added method `pkeys()` to the `pg.DB` object. - Removed deprecated function `pg.pgnotify()`. -- Removed the deprecated method `ntuples()` of the `pg.Query` object. +- Removed deprecated method `ntuples()` of the `pg.Query` object. - Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. - Modernized code and tools for development, testing, linting and building. diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 64710456..ea4f71c1 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -83,6 +83,8 @@ This method returns the primary keys of a table as a tuple, i.e. single primary keys are also returned as a tuple with one item. Note that this raises a KeyError if the table does not have a primary key. +.. versionadded:: 6.0 + get_databases -- get list of databases in the system ---------------------------------------------------- diff --git a/pg.py b/pg.py index 11aaf90a..45f8ae46 100644 --- a/pg.py +++ b/pg.py @@ -242,7 +242,8 @@ def __str__(self) -> str: class Json: """Wrapper class for marking Json values.""" - def __init__(self, obj: Any, encode: Callable | None = None) -> None: + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode @@ -1014,7 +1015,7 @@ class Typecasts(dict): connection: DB | None = None # set in a connection specific instance - def __missing__(self, typ: Any) -> Callable | None: + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -1047,8 +1048,7 @@ def _needs_connection(func: Callable) -> bool: args = get_args(func) except (TypeError, ValueError): return False - else: - return 'connection' in args[1:] + return 'connection' in args[1:] def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" @@ -1056,11 +1056,12 @@ def _add_connection(self, cast: Callable) -> Callable: return cast return partial(cast, connection=self.connection) - def get(self, typ: Any, default: Any = None) -> Any: + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ: Any, cast: Callable) -> None: + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1075,7 +1076,7 @@ def set(self, typ: Any, cast: Callable) -> None: self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ: Any = None) -> None: + def reset(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. 
@@ -1089,12 +1090,13 @@ def reset(self, typ: Any = None) -> None: self.pop(t, None) @classmethod - def get_default(cls, typ: Any) -> Any: + def get_default(cls, typ: str) -> Any: """Get the default typecast function for the given database type.""" return cls.defaults.get(typ) @classmethod - def set_default(cls, typ: Any, cast: Callable | None) -> None: + def set_default(cls, typ: str | Sequence[str], + cast: Callable | None) -> None: """Set a default typecast function for the given database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1130,7 +1132,7 @@ def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v: Any) -> Callable: + def cast(v: Any) -> list: return cast_array(v, basecast) return cast @@ -1146,12 +1148,12 @@ def cast(v: Any) -> record: return cast -def get_typecast(typ: Any) -> Callable | None: - """Get the global typecast function for the given database type(s).""" +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" return Typecasts.get_default(typ) -def set_typecast(typ: Any, cast: Callable | None) -> None: +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -1254,7 +1256,8 @@ def __missing__(self, key: int | str) -> DbType: self[typ.oid] = self[typ.pgtype] = typ return typ - def get(self, key: int | str, default: Any = None) -> Any: + def get(self, key: int | str, # type: ignore + default: DbType | None = None) -> DbType | None: """Get the type even if it is not cached.""" try: return self[key] @@ -1271,27 +1274,27 @@ def get_attnames(self, typ: Any) -> AttrDict | None: return None return self._db.get_attnames(typ.relid, with_oid=False) - def get_typecast(self, typ: Any) -> Callable: + def get_typecast(self, typ: Any) -> Callable | None: """Get the typecast function for the given database type.""" return self._typecasts.get(typ) - def set_typecast(self, typ: Any, cast: Callable) -> None: + def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ: Any = None) -> None: + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value: Any, typ: Any) -> Callable | None: + def typecast(self, value: Any, typ: str) -> Any: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary return None if not isinstance(typ, DbType): - typ = self.get(typ) - if typ: - typ = typ.pgtype + db_type = self.get(typ) + if db_type: + typ = db_type.pgtype cast = self.get_typecast(typ) if typ else None if not cast or cast is str: # no typecast is necessary @@ -1373,6 +1376,7 @@ def getresult(self) -> Any: def __iter__(self) -> Iterator[Any]: return iter(self.result) +# Error messages E = TypeVar('E', bound=DatabaseError) diff --git a/pgdb.py b/pgdb.py index df23bbfd..332ca3d0 100644 --- a/pgdb.py +++ b/pgdb.py @@ -69,7 +69,7 @@ from collections import namedtuple from collections.abc import Iterable from contextlib import suppress -from datetime import date, datetime, time, timedelta +from datetime import date, datetime, time, 
timedelta, tzinfo from decimal import Decimal as StdDecimal from functools import lru_cache, partial from inspect import signature @@ -78,7 +78,16 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from typing import Callable, ClassVar +from typing import ( + Any, + Callable, + ClassVar, + Generator, + Mapping, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID as Uuid # noqa: N811 try: @@ -131,10 +140,15 @@ cast_array, cast_hstore, cast_record, - connect, unescape_bytea, version, ) +from _pg import ( + Connection as Cnx, # base connection +) +from _pg import ( + connect as get_cnx, # get base connection +) __version__ = version @@ -197,7 +211,7 @@ def _timezone_as_offset(tz: str) -> str: return _timezones.get(tz, '+0000') -def decimal_type(decimal_type: type | None = None): +def decimal_type(decimal_type: type | None = None) -> type: """Get or set global type to be used for decimal values. Note that connections cache cast functions. To be sure a global change @@ -212,15 +226,15 @@ def decimal_type(decimal_type: type | None = None): def cast_bool(value: str) -> bool | None: """Cast boolean value in database format to bool.""" - if value: - return value[0] in ('t', 'T') + return value[0] in ('t', 'T') if value else None -def cast_money(value: str) -> Decimal | None: # pyright: ignore +def cast_money(value: str) -> StdDecimal | None: """Cast money value in database format to Decimal.""" - if value: - value = value.replace('(', '-') - return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) + if not value: + return None + value = value.replace('(', '-') + return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) def cast_int2vector(value: str) -> list[int]: @@ -228,7 +242,7 @@ def cast_int2vector(value: str) -> list[int]: return [int(v) for v in value.split()] -def cast_date(value: str, connection) -> date: +def cast_date(value: str, cnx: Cnx) -> date: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. 
The @@ -244,7 +258,7 @@ def cast_date(value: str, connection) -> date: value = values[0] if len(value) > 10: return date.max - format = connection.date_format() + format = cnx.date_format() return datetime.strptime(value, format).date() @@ -270,7 +284,7 @@ def cast_timetz(value: str) -> time: return datetime.strptime(value, format).timetz() -def cast_timestamp(value: str, connection) -> datetime: +def cast_timestamp(value: str, cnx: Cnx) -> datetime: """Cast a timestamp value.""" if value == '-infinity': return datetime.min @@ -279,7 +293,7 @@ def cast_timestamp(value: str, connection) -> datetime: values = value.split() if values[-1] == 'BC': return datetime.min - format = connection.date_format() + format = cnx.date_format() if format.endswith('-%Y') and len(values) > 2: values = values[1:5] if len(values[3]) > 4: @@ -293,7 +307,7 @@ def cast_timestamp(value: str, connection) -> datetime: return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value: str, connection) -> datetime: +def cast_timestamptz(value: str, cnx: Cnx) -> datetime: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min @@ -302,7 +316,7 @@ def cast_timestamptz(value: str, connection) -> datetime: values = value.split() if values[-1] == 'BC': return datetime.min - format = connection.date_format() + format = cnx.date_format() if format.endswith('-%Y') and len(values) > 2: values = values[1:] if len(values[3]) > 4: @@ -439,9 +453,9 @@ class Typecasts(dict): 'int2vector': cast_int2vector, 'uuid': Uuid, 'anyarray': cast_array, 'record': cast_record} - connection = None # will be set in local connection specific instances + cnx: Cnx | None = None # for local connection specific instances - def __missing__(self, typ): + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -464,26 +478,26 @@ def __missing__(self, typ): return cast @staticmethod - def _needs_connection(func): + def _needs_connection(func: Callable) -> bool: """Check if a typecast function needs a connection argument.""" try: args = get_args(func) except (TypeError, ValueError): return False - else: - return 'connection' in args[1:] + return 'cnx' in args[1:] - def _add_connection(self, cast): + def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): + if not self.cnx or not self._needs_connection(cast): return cast - return partial(cast, connection=self.connection) + return partial(cast, cnx=self.cnx) - def get(self, typ, default=None): + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ, cast): + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -498,7 +512,7 @@ def set(self, typ, cast): self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ=None): + def reset(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. 
@@ -524,20 +538,21 @@ def reset(self, typ=None): self.pop(t, None) self.pop(f'_{t}', None) - def create_array_cast(self, basecast): + def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v): + def cast(v: Any) -> list: return cast_array(v, basecast) return cast - def create_record_cast(self, name, fields, casts): + def create_record_cast(self, name: str, fields: Sequence[str], + casts: Sequence[str]) -> Callable: """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] - record = namedtuple(name, fields) + record = namedtuple(name, fields) # type: ignore - def cast(v): + def cast(v: Any) -> record: # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast @@ -546,12 +561,12 @@ def cast(v): _typecasts = Typecasts() # this is the global typecast dictionary -def get_typecast(typ): - """Get the global typecast function for the given database type(s).""" +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" return _typecasts.get(typ) -def set_typecast(typ, cast): +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -560,7 +575,7 @@ def set_typecast(typ, cast): _typecasts.set(typ, cast) -def reset_typecast(typ=None): +def reset_typecast(typ: str | Sequence[str] | None = None) -> None: """Reset the global typecasts for the given type(s) to their default. When no type is specified, all typecasts will be reset. @@ -576,10 +591,11 @@ class LocalTypecasts(Typecasts): defaults = _typecasts - connection = None # will be set in a connection specific instance + cnx: Cnx | None = None # set in connection specific instances - def __missing__(self, typ): + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached.""" + cast: Callable | None if typ.startswith('_'): base_cast = self[typ[1:]] cast = self.create_array_cast(base_cast) @@ -594,13 +610,13 @@ def __missing__(self, typ): fields = self.get_fields(typ) if fields: casts = [self[field.type] for field in fields] - fields = [field.name for field in fields] - cast = self.create_record_cast(typ, fields, casts) + field_names = [field.name for field in fields] + cast = self.create_record_cast(typ, field_names, casts) self[typ] = cast return cast # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_fields(self, typ): + def get_fields(self, typ: str) -> list[FieldInfo]: """Return the fields for the given record type. This method will be replaced with a method that looks up the fields @@ -616,9 +632,17 @@ class TypeCode(str): but carry some additional information. """ + oid: int + len: int + type: str + category: str + delim: str + relid: int + # noinspection PyShadowingBuiltins @classmethod - def create(cls, oid, name, len, type, category, delim, relid): + def create(cls, oid: int, name: str, len: int, type: str, category: str, + delim: str, relid: int) -> TypeCode: """Create a type code for a PostgreSQL data type.""" self = cls(name) self.oid = oid @@ -640,21 +664,22 @@ class TypeCache(dict): important information on the associated database type. 
""" - def __init__(self, cnx): + def __init__(self, cnx: Cnx) -> None: """Initialize type cache for connection.""" super().__init__() self._escape_string = cnx.escape_string self._src = cnx.source() self._typecasts = LocalTypecasts() - self._typecasts.get_fields = self.get_fields - self._typecasts.connection = cnx + self._typecasts.get_fields = self.get_fields # type: ignore + self._typecasts.cnx = cnx self._query_pg_type = ( "SELECT oid, typname," " typlen, typtype, typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") - def __missing__(self, key): + def __missing__(self, key: int | str) -> TypeCode: """Get the type info from the database if it is not cached.""" + oid: int | str if isinstance(key, int): oid = key else: @@ -677,43 +702,48 @@ def __missing__(self, key): self[type_code.oid] = self[str(type_code)] = type_code return type_code - def get(self, key, default=None): + def get(self, key: int | str, # type: ignore + default: TypeCode | None = None) -> TypeCode | None: """Get the type even if it is not cached.""" try: return self[key] except KeyError: return default - def get_fields(self, typ): + def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: """Get the names and types of the fields of composite types.""" - if not isinstance(typ, TypeCode): - typ = self.get(typ) - if not typ: + if isinstance(typ, TypeCode): + relid = typ.relid + else: + type_code = self.get(typ) + if not type_code: return None - if not typ.relid: + relid = type_code.relid + if not relid: return None # this type is not composite self._src.execute( "SELECT attname, atttypid" # noqa: S608 " FROM pg_catalog.pg_attribute" - f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}" + f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" " AND attnum OPERATOR(pg_catalog.>) 0" " AND NOT attisdropped ORDER BY attnum") return [FieldInfo(name, self.get(int(oid))) for name, oid in self._src.fetch(-1)] - def get_typecast(self, typ): + def get_typecast(self, typ: str) -> Callable | None: """Get the typecast function for the given database type.""" return self._typecasts[typ] - def set_typecast(self, typ, cast): + def set_typecast(self, typ: str | Sequence[str], + cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ=None): + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value, typ): + def typecast(self, value: Any, typ: str) -> Any: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary @@ -724,13 +754,13 @@ def typecast(self, value, typ): return value return cast(value) - def get_row_caster(self, types): + def get_row_caster(self, types: Sequence[str]) -> Callable: """Get a typecast function for a complete row of values.""" typecasts = self._typecasts casts = [typecasts[typ] for typ in types] casts = [cast if cast is not str else None for cast in casts] - def row_caster(row): + def row_caster(row: Sequence) -> Sequence: return [value if cast is None or value is None else cast(value) for cast, value in zip(casts, row)] @@ -743,22 +773,26 @@ class _QuoteDict(dict): The quote attribute must be set to the desired quote function. 
""" - def __getitem__(self, key): + quote: Callable[[str], str] + + def __getitem__(self, key: str) -> str: # noinspection PyUnresolvedReferences return self.quote(super().__getitem__(key)) # *** Error Messages *** +E = TypeVar('E', bound=DatabaseError) + -def _db_error(msg, cls=DatabaseError): +def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: """Return DatabaseError with empty sqlstate attribute.""" error = cls(msg) error.sqlstate = None return error -def _op_error(msg): +def _op_error(msg: str) -> OperationalError: """Return OperationalError.""" return _db_error(msg, OperationalError) @@ -771,16 +805,16 @@ def _op_error(msg): # noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) -def _row_factory(names): +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: """Get a namedtuple factory for row results with the given names.""" try: - return namedtuple('Row', names, rename=True)._make + return namedtuple('Row', names, rename=True)._make # type: ignore except ValueError: # there is still a problem with the field names names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make + return namedtuple('Row', names)._make # type: ignore -def set_row_factory_size(maxsize): +def set_row_factory_size(maxsize: int | None) -> None: """Change the size of the namedtuple factory cache. If maxsize is set to None, the cache can grow without bound. @@ -795,46 +829,51 @@ def set_row_factory_size(maxsize): class Cursor: """Cursor object.""" - def __init__(self, dbcnx): + def __init__(self, connection: Connection) -> None: """Create a cursor object for the database connection.""" - self.connection = self._dbcnx = dbcnx - self._cnx = dbcnx._cnx - self.type_cache = dbcnx.type_cache + self.connection = self._connection = connection + cnx = connection._cnx + if not cnx: + raise _op_error("Connection has been closed") + self._cnx = cnx + self.type_cache = connection.type_cache self._src = self._cnx.source() # the official attribute for describing the result columns - self._description = None + self._description: list[CursorDescription] | bool | None = None if self.row_factory is Cursor.row_factory: # the row factory needs to be determined dynamically - self.row_factory = None + self.row_factory = None # type: ignore else: - self.build_row_factory = None + self.build_row_factory = None # type: ignore self.rowcount = -1 self.arraysize = 1 self.lastrowid = None - def __iter__(self): + def __iter__(self) -> Cursor: """Make cursor compatible to the iteration protocol.""" return self - def __enter__(self): + def __enter__(self) -> Cursor: """Enter the runtime context for the cursor object.""" return self - def __exit__(self, et, ev, tb): + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: """Exit the runtime context for the cursor object.""" self.close() - def _quote(self, value): + def _quote(self, value: Any) -> Any: """Quote value depending on its type.""" if value is None: return 'NULL' if isinstance(value, (Hstore, Json)): value = str(value) if isinstance(value, (bytes, str)): + cnx = self._cnx if isinstance(value, Binary): - value = self._cnx.escape_bytea(value).decode('ascii') + value = cnx.escape_bytea(value).decode('ascii') else: - value = self._cnx.escape_string(value) + value = cnx.escape_string(value) return f"'{value}'" if isinstance(value, float): if isinf(value): @@ -887,7 +926,8 @@ def _quote(self, value): value = self._quote(value) return value - def _quoteparams(self, string, 
parameters): + def _quoteparams(self, string: str, + parameters: Mapping | Sequence | None) -> str: """Quote parameters. This function works for both mappings and sequences. @@ -907,12 +947,15 @@ def _quoteparams(self, string, parameters): parameters = tuple(map(self._quote, parameters)) return string % parameters - def _make_description(self, info): + def _make_description(self, info: tuple[int, str, int, int, int] + ) -> CursorDescription: """Make the description tuple for the given field info.""" name, typ, size, mod = info[1:] type_code = self.type_cache[typ] if mod > 0: mod -= 4 + precision: int | None + scale: int | None if type_code == 'numeric': precision, scale = mod >> 16, mod & 0xffff size = precision @@ -922,34 +965,39 @@ def _make_description(self, info): if size == -1: size = mod precision = scale = None - return CursorDescription(name, type_code, - None, size, precision, scale, None) + return CursorDescription( + name, type_code, None, size, precision, scale, None) @property - def description(self): + def description(self) -> list[CursorDescription] | None: """Read-only attribute describing the result columns.""" - descr = self._description - if self._description is True: + description = self._description + if description is None: + return None + if not isinstance(description, list): make = self._make_description - descr = [make(info) for info in self._src.listinfo()] - self._description = descr - return descr + description = [make(info) for info in self._src.listinfo()] + self._description = description + return description @property - def colnames(self): + def colnames(self) -> Sequence[str] | None: """Unofficial convenience method for getting the column names.""" - return [d[0] for d in self.description] + description = self.description + return None if description is None else [d[0] for d in description] @property - def coltypes(self): + def coltypes(self) -> Sequence[TypeCode] | None: """Unofficial convenience method for getting the column types.""" - return [d[1] for d in self.description] + description = self.description + return None if description is None else [d[1] for d in description] - def close(self): + def close(self) -> None: """Close the cursor object.""" self._src.close() - def execute(self, operation, parameters=None): + def execute(self, operation: str, parameters: Sequence | None = None + ) -> Cursor: """Prepare and execute a database operation (query or command).""" # The parameters may also be specified as list of tuples to e.g. 
# insert multiple rows in a single operation, but this kind of @@ -960,22 +1008,22 @@ def execute(self, operation, parameters=None): and all(isinstance(p, tuple) for p in parameters) and all(len(p) == len(parameters[0]) for p in parameters[1:])): return self.executemany(operation, parameters) - else: - # not a list of tuples - return self.executemany(operation, [parameters]) + # not a list of tuples + return self.executemany(operation, [parameters]) - def executemany(self, operation, seq_of_parameters): + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None]) -> Cursor: """Prepare operation and execute it against a parameter sequence.""" if not seq_of_parameters: # don't do anything without parameters - return + return self self._description = None self.rowcount = -1 # first try to execute all queries rowcount = 0 sql = "BEGIN" try: - if not self._dbcnx._tnx and not self._dbcnx.autocommit: + if not self._connection._tnx and not self._connection.autocommit: try: self._src.execute(sql) except DatabaseError: @@ -983,7 +1031,7 @@ def executemany(self, operation, seq_of_parameters): except Exception as e: raise _op_error("Can't start transaction") from e else: - self._dbcnx._tnx = True + self._connection._tnx = True for parameters in seq_of_parameters: sql = operation sql = self._quoteparams(sql, parameters) @@ -1005,8 +1053,9 @@ def executemany(self, operation, seq_of_parameters): self._description = True # fetch on demand self.rowcount = self._src.ntuples self.lastrowid = None - if self.build_row_factory: - self.row_factory = self.build_row_factory() + build_row_factory = self.build_row_factory + if build_row_factory: # type: ignore + self.row_factory = build_row_factory() # type: ignore else: self.rowcount = rowcount self.lastrowid = self._src.oidstatus() @@ -1014,7 +1063,7 @@ def executemany(self, operation, seq_of_parameters): # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" return self - def fetchone(self): + def fetchone(self) -> Sequence | None: """Fetch the next row of a query result set.""" res = self.fetchmany(1, False) try: @@ -1022,11 +1071,12 @@ def fetchone(self): except IndexError: return None - def fetchall(self): + def fetchall(self) -> Sequence[Sequence]: """Fetch all (remaining) rows of a query result.""" return self.fetchmany(-1, False) - def fetchmany(self, size=None, keep=False): + def fetchmany(self, size: int | None = None, keep: bool = False + ) -> Sequence[Sequence]: """Fetch the next set of rows of a query result. The number of rows to fetch per call is specified by the @@ -1046,6 +1096,9 @@ def fetchmany(self, size=None, keep=False): raise _db_error(str(err)) from err row_factory = self.row_factory coltypes = self.coltypes + if coltypes is None: + # cannot determine column types, return raw result + return [row_factory(row) for row in result] if len(result) > 5: # optimize the case where we really fetch many values # by looking up all type casting functions upfront @@ -1055,7 +1108,8 @@ def fetchmany(self, size=None, keep=False): return [row_factory([cast_value(value, typ) for typ, value in zip(coltypes, row)]) for row in result] - def callproc(self, procname, parameters=None): + def callproc(self, procname: str, parameters: Sequence | None = None + ) -> Sequence | None: """Call a stored database procedure with the given name. 
The sequence of parameters must contain one entry for each input @@ -1073,15 +1127,17 @@ def callproc(self, procname, parameters=None): return parameters # noinspection PyShadowingBuiltins - def copy_from(self, stream, table, - format=None, sep=None, null=None, size=None, columns=None): + def copy_from(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, size: int | None = None, + columns: Sequence[str] | None = None) -> Cursor: """Copy data from an input stream to the specified table. The input stream can be a file-like object with a read() method or it can also be an iterable returning a row or multiple rows of input on each iteration. - The format must be text, csv or binary. The sep option sets the + The format must be 'text', 'csv' or 'binary'. The sep option sets the column separator (delimiter) used in the non binary formats. The null option sets the textual representation of NULL in the input. @@ -1098,6 +1154,8 @@ def copy_from(self, stream, table, if size: raise ValueError( "Size must only be set for file-like objects") from e + input_type: type | tuple[type, ...] + type_name: str if binary_format: input_type = bytes type_name = 'byte strings' @@ -1116,12 +1174,12 @@ def copy_from(self, stream, table, if not stream.endswith(b'\n'): stream += b'\n' - def chunks(): + def chunks() -> Generator: yield stream elif isinstance(stream, Iterable): - def chunks(): + def chunks() -> Generator: for chunk in stream: if not isinstance(chunk, input_type): raise ValueError( @@ -1143,7 +1201,7 @@ def chunks(): raise TypeError("The size option must be an integer") if size > 0: - def chunks(): + def chunks() -> Generator: while True: buffer = read(size) yield buffer @@ -1152,19 +1210,18 @@ def chunks(): else: - def chunks(): + def chunks() -> Generator: yield read() if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): raise ValueError("Must specify a table, not a query") - else: - table = '.'.join(map( - self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = [f'copy {table}'] + cnx = self._cnx + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] options = [] - params = [] + parameters = [] if format is not None: if not isinstance(format, str): raise TypeError("The format option must be be a string") @@ -1181,25 +1238,23 @@ def chunks(): raise ValueError( "The sep option must be a single one-byte character") options.append('delimiter %s') - params.append(sep) + parameters.append(sep) if null is not None: if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') - params.append(null) + parameters.append(null) if columns: if not isinstance(columns, str): - columns = ','.join(map( - self.connection._cnx.escape_identifier, columns)) - operation.append(f'({columns})') - operation.append("from stdin") + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + operation_parts.append("from stdin") if options: - options = ','.join(options) - operation.append(f'({options})') - operation = ' '.join(operation) + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) putdata = self._src.putdata - self.execute(operation, params) + self.execute(operation, parameters) try: for chunk in chunks(): @@ -1215,8 +1270,10 @@ def chunks(): return self # noinspection PyShadowingBuiltins - def 
copy_to(self, stream, table, - format=None, sep=None, null=None, decode=None, columns=None): + def copy_to(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, decode: bool | None = None, + columns: Sequence[str] | None = None) -> Cursor | Generator: """Copy data from the specified table to an output stream. The output stream can be a file-like object with a write() method or @@ -1227,7 +1284,7 @@ def copy_to(self, stream, table, Note that you can also use a select query instead of the table name. - The format must be text, csv or binary. The sep option sets the + The format must be 'text', 'csv' or 'binary'. The sep option sets the column separator (delimiter) used in the non binary formats. The null option sets the textual representation of NULL in the output. @@ -1235,23 +1292,25 @@ def copy_to(self, stream, table, columns are specified, all of them will be copied. """ binary_format = format == 'binary' - if stream is not None: + if stream is None: + write = None + else: try: write = stream.write except AttributeError as e: raise TypeError("Need an output stream to copy to") from e if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") + cnx = self._cnx if table.lower().startswith('select '): if columns: raise ValueError("Columns must be specified in the query") table = f'({table})' else: - table = '.'.join(map( - self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = [f'copy {table}'] + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] options = [] - params = [] + parameters = [] if format is not None: if not isinstance(format, str): raise TypeError("The format option must be a string") @@ -1268,12 +1327,12 @@ def copy_to(self, stream, table, raise ValueError( "The sep option must be a single one-byte character") options.append('delimiter %s') - params.append(sep) + parameters.append(sep) if null is not None: if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') - params.append(null) + parameters.append(null) if decode is None: decode = format != 'binary' else: @@ -1284,20 +1343,18 @@ def copy_to(self, stream, table, "The decode option is not allowed with binary format") if columns: if not isinstance(columns, str): - columns = ','.join(map( - self.connection._cnx.escape_identifier, columns)) - operation.append(f'({columns})') + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') - operation.append("to stdout") + operation_parts.append("to stdout") if options: - options = ','.join(options) - operation.append(f'({options})') - operation = ' '.join(operation) + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) getdata = self._src.getdata - self.execute(operation, params) + self.execute(operation, parameters) - def copy(): + def copy() -> Generator: self.rowcount = 0 while True: row = getdata(decode) @@ -1308,7 +1365,7 @@ def copy(): self.rowcount += 1 yield row - if stream is None: + if write is None: # no input stream, return the generator return copy() @@ -1320,7 +1377,7 @@ def copy(): # return the cursor object, so you can chain operations return self - def __next__(self): + def __next__(self) -> Sequence: """Return the next row (support for the iteration protocol).""" res = self.fetchone() if res is None: @@ -1332,22 +1389,22 @@ def __next__(self): next = __next__ @staticmethod - def 
nextset(): + def nextset() -> bool | None: """Not supported.""" raise NotSupportedError("The nextset() method is not supported") @staticmethod - def setinputsizes(sizes): + def setinputsizes(sizes: Sequence[int]) -> None: """Not supported.""" pass # unsupported, but silently passed @staticmethod - def setoutputsize(size, column=0): + def setoutputsize(size: int, column: int = 0) -> None: """Not supported.""" pass # unsupported, but silently passed @staticmethod - def row_factory(row): + def row_factory(row: Sequence) -> Sequence: """Process rows before they are returned. You can overwrite this statically with a custom row factory, or @@ -1367,7 +1424,7 @@ def row_factory(self, row): """ raise NotImplementedError - def build_row_factory(self): + def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: """Build a row factory based on the current description. This implementation builds a row factory for creating named tuples. @@ -1375,8 +1432,7 @@ def build_row_factory(self): different row factories whenever the column description changes. """ names = self.colnames - if names: - return _row_factory(tuple(names)) + return _row_factory(tuple(names)) if names else None CursorDescription = namedtuple('CursorDescription', ( @@ -1401,7 +1457,7 @@ class Connection: DataError = DataError NotSupportedError = NotSupportedError - def __init__(self, cnx): + def __init__(self, cnx: Cnx) -> None: """Create a database connection object.""" self._cnx = cnx # connection self._tnx = False # transaction state @@ -1413,7 +1469,7 @@ def __init__(self, cnx): except Exception as e: raise _op_error("Invalid connection") from e - def __enter__(self): + def __enter__(self) -> Connection: """Enter the runtime context for the connection object. The runtime context can be used for running transactions. @@ -1421,8 +1477,11 @@ def __enter__(self): This also starts a transaction in autocommit mode. """ if self.autocommit: + cnx = self._cnx + if not cnx: + raise _op_error("Connection has been closed") try: - self._cnx.source().execute("BEGIN") + cnx.source().execute("BEGIN") except DatabaseError: raise # database provides error message except Exception as e: @@ -1431,7 +1490,8 @@ def __enter__(self): self._tnx = True return self - def __exit__(self, et, ev, tb): + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: """Exit the runtime context for the connection object. This does not close the connection, but it ends a transaction. 
@@ -1441,103 +1501,101 @@ def __exit__(self, et, ev, tb): else: self.rollback() - def close(self): + def close(self) -> None: """Close the connection object.""" - if self._cnx: - if self._tnx: - with suppress(DatabaseError): - self.rollback() - self._cnx.close() - self._cnx = None - else: + if not self._cnx: raise _op_error("Connection has been closed") + if self._tnx: + with suppress(DatabaseError): + self.rollback() + self._cnx.close() + self._cnx = None @property - def closed(self): + def closed(self) -> bool: """Check whether the connection has been closed or is broken.""" try: return not self._cnx or self._cnx.status != 1 except TypeError: return True - def commit(self): + def commit(self) -> None: """Commit any pending transaction to the database.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("COMMIT") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't commit transaction") from e - else: + if not self._cnx: raise _op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("COMMIT") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise _op_error("Can't commit transaction") from e - def rollback(self): + def rollback(self) -> None: """Roll back to the start of any pending transaction.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("ROLLBACK") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't rollback transaction") from e - else: + if not self._cnx: raise _op_error("Connection has been closed") - - def cursor(self): - """Return a new cursor object using the connection.""" - if self._cnx: + if self._tnx: + self._tnx = False try: - return self.cursor_type(self) + self._cnx.source().execute("ROLLBACK") + except DatabaseError: + raise # database provides error message except Exception as e: - raise _op_error("Invalid connection") from e - else: + raise _op_error("Can't rollback transaction") from e + + def cursor(self) -> Cursor: + """Return a new cursor object using the connection.""" + if not self._cnx: raise _op_error("Connection has been closed") + try: + return self.cursor_type(self) + except Exception as e: + raise _op_error("Invalid connection") from e if shortcutmethods: # otherwise do not implement and document this - def execute(self, operation, params=None): + def execute(self, operation: str, + parameters: Sequence | None = None) -> Cursor: """Shortcut method to run an operation on an implicit cursor.""" cursor = self.cursor() - cursor.execute(operation, params) + cursor.execute(operation, parameters) return cursor - def executemany(self, operation, param_seq): + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None] + ) -> Cursor: """Shortcut method to run an operation against a sequence.""" cursor = self.cursor() - cursor.executemany(operation, param_seq) + cursor.executemany(operation, seq_of_parameters) return cursor # *** Module Interface *** -_connect = connect - - -def connect(dsn=None, - user=None, password=None, - host=None, database=None, **kwargs): +def connect(dsn: str | None = None, + user: str | None = None, password: str | None = None, + host: str | None = None, database: str | None = None, + **kwargs: Any) -> Connection: """Connect to a database.""" # first get params from DSN dbport = -1 - dbhost = "" - dbname = "" - 
dbuser = "" - dbpasswd = "" - dbopt = "" - try: - params = dsn.split(":") - dbhost = params[0] - dbname = params[1] - dbuser = params[2] - dbpasswd = params[3] - dbopt = params[4] - except (AttributeError, IndexError, TypeError): - pass + dbhost: str | None = "" + dbname: str | None = "" + dbuser: str | None = "" + dbpasswd: str | None = "" + dbopt: str | None = "" + if dsn: + try: + params = dsn.split(":", 4) + dbhost = params[0] + dbname = params[1] + dbuser = params[2] + dbpasswd = params[3] + dbopt = params[4] + except (AttributeError, IndexError, TypeError): + pass # override if necessary if user is not None: @@ -1546,9 +1604,9 @@ def connect(dsn=None, dbpasswd = password if database is not None: dbname = database - if host is not None: + if host: try: - params = host.split(":") + params = host.split(":", 1) dbhost = params[0] dbport = int(params[1]) except (AttributeError, IndexError, TypeError, ValueError): @@ -1562,22 +1620,21 @@ def connect(dsn=None, # pass keyword arguments as connection info string if kwargs: - kwargs = list(kwargs.items()) - if '=' in dbname: - dbname = [dbname] + kwarg_list = list(kwargs.items()) + kw_parts = [] + if dbname and '=' in dbname: + kw_parts.append(dbname) else: - kwargs.insert(0, ('dbname', dbname)) - dbname = [] - for kw, value in kwargs: + kwarg_list.insert(0, ('dbname', dbname)) + for kw, value in kwarg_list: value = str(value) if not value or ' ' in value: value = value.replace('\\', '\\\\').replace("'", "\\'") value = f"'{value}'" - dbname.append(f'{kw}={value}') - dbname = ' '.join(dbname) + kw_parts.append(f'{kw}={value}') + dbname = ' '.join(kw_parts) # open the connection - # noinspection PyArgumentList - cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) + cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) return Connection(cnx) @@ -1590,67 +1647,61 @@ class DbType(frozenset): We must thus use type names as internal type codes. 
""" - def __new__(cls, values): + def __new__(cls, values: str | Iterable[str]) -> DbType: """Create new type object.""" if isinstance(values, str): values = values.split() - return super().__new__(cls, values) + return super().__new__(cls, values) # type: ignore - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Check whether types are considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self - else: - return super().__eq__(other) + return super().__eq__(other) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Check whether types are not considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self - else: - return super().__ne__(other) + return super().__ne__(other) class ArrayType: """Type class for PostgreSQL array types.""" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, str): return other.startswith('_') - else: - return isinstance(other, ArrayType) + return isinstance(other, ArrayType) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, str): return not other.startswith('_') - else: - return not isinstance(other, ArrayType) + return not isinstance(other, ArrayType) class RecordType: """Type class for PostgreSQL record types.""" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type == 'c' - elif isinstance(other, str): + if isinstance(other, str): return other == 'record' - else: - return isinstance(other, RecordType) + return isinstance(other, RecordType) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type != 'c' - elif isinstance(other, str): + if isinstance(other, str): return other != 'record' - else: - return not isinstance(other, RecordType) + return not isinstance(other, RecordType) # Mandatory type objects defined by DB-API 2 specs: @@ -1691,35 +1742,38 @@ def __ne__(self, other): # Mandatory type helpers defined by DB-API 2 specs: -def Date(year, month, day): # noqa: N802 +def Date(year: int, month: int, day: int) -> date: # noqa: N802 """Construct an object holding a date value.""" return date(year, month, day) -def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): # noqa: N802 +def Time(hour: int, minute: int = 0, # noqa: N802 + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> time: """Construct an object holding a time value.""" return time(hour, minute, second, microsecond, tzinfo) -def Timestamp(year, month, day, # noqa: N802 - hour=0, minute=0, second=0, microsecond=0, - tzinfo=None): +def Timestamp(year: int, month: int, day: int, # noqa: N802 + hour: int = 0, minute: int = 0, + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> datetime: """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, second, microsecond, - tzinfo) + return datetime(year, month, day, hour, minute, + second, microsecond, tzinfo) -def DateFromTicks(ticks): # noqa: N802 +def DateFromTicks(ticks: float | None) -> date: # noqa: N802 """Construct an object holding a date value from the given ticks value.""" return Date(*localtime(ticks)[:3]) -def TimeFromTicks(ticks): # noqa: N802 +def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 """Construct an object holding a 
time value from the given ticks value.""" return Time(*localtime(ticks)[3:6]) -def TimestampFromTicks(ticks): # noqa: N802 +def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 """Construct an object holding a time stamp from the given ticks value.""" return Timestamp(*localtime(ticks)[:6]) @@ -1730,11 +1784,13 @@ class Binary(bytes): # Additional type helpers for PyGreSQL: -def Interval(days, # noqa: N802 - hours=0, minutes=0, seconds=0, microseconds=0): +def Interval(days: int | float, # noqa: N802 + hours: int | float = 0, minutes: int | float = 0, + seconds: int | float = 0, microseconds: int | float = 0 + ) -> timedelta: """Construct an object holding a time interval value.""" - return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, - microseconds=microseconds) + return timedelta(days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds) Uuid = Uuid # Construct an object holding a UUID value @@ -1747,7 +1803,7 @@ class Hstore(dict): _re_escape = regex(r'(["\\])') @classmethod - def _quote(cls, s): + def _quote(cls, s: Any) -> Any: if s is None: return 'NULL' if not isinstance(s, str): @@ -1760,7 +1816,7 @@ def _quote(cls, s): s = f'"{s}"' return s - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -1769,12 +1825,13 @@ def __str__(self): class Json: """Construct a wrapper for holding an object serializable to JSON.""" - def __init__(self, obj, encode=None): + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): @@ -1785,11 +1842,11 @@ def __str__(self): class Literal: """Construct a wrapper for holding a literal SQL string.""" - def __init__(self, sql): + def __init__(self, sql: str) -> None: """Initialize literal SQL string.""" self.sql = sql - def __str__(self): + def __str__(self) -> str: """Return a printable representation of the SQL string.""" return self.sql diff --git a/pgsource.c b/pgsource.c index 73c9a52b..9bc6bb4a 100644 --- a/pgsource.c +++ b/pgsource.c @@ -680,7 +680,8 @@ _source_buildinfo(sourceObject *self, int num) /* Lists fields info. */ static char source_listinfo__doc__[] = - "listinfo() -- get information for all fields (position, name, type oid)"; + "listinfo() -- get information for all fields" + " (position, name, type oid, size, type modifier)"; static PyObject * source_listInfo(sourceObject *self, PyObject *noargs) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index f72d99f7..0c038f72 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -12,7 +12,7 @@ import time import unittest from contextlib import suppress -from typing import Any, Mapping +from typing import Any, ClassVar __version__ = '1.15.0' @@ -22,7 +22,7 @@ class DatabaseAPI20Test(unittest.TestCase): This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this test case to ensure compliance with the DB-API. It is - expected that this TestCase may be expanded i qn the future + expected that this TestCase may be expanded in the future if ambiguities or edge conditions are discovered. The 'Optional Extensions' are not yet being tested.
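Stepping back to the type helper annotations above (``Interval``, ``Json``, ``Literal`` and friends), a few concrete values; a minimal sketch, with indicative results shown in the comments::

    from pgdb import Date, Interval, Json, Literal

    Date(2023, 9, 4)           # datetime.date(2023, 9, 4)
    Interval(1, hours=2)       # datetime.timedelta(days=1, seconds=7200)
    str(Json({'answer': 42}))  # '{"answer": 42}'
    str(Literal('now()'))      # 'now()', passed through as raw SQL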
@@ -43,7 +43,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # method is to be found driver: Any = None connect_args: tuple = () # List of arguments to pass to connect - connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect + connect_kw_args: ClassVar[dict[str, Any]] = {} # Keyword arguments table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = f'create table {table_prefix}booze (name varchar(20))' diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 31aec400..71438f71 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -23,7 +23,7 @@ from io import StringIO from operator import itemgetter from time import strftime -from typing import Any, ClassVar +from typing import Any, Callable, ClassVar from uuid import UUID import pg # the module under test @@ -4910,7 +4910,7 @@ def test_debug_multiple_args(self): class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" - def get_leaks(self, fut): + def get_leaks(self, fut: Callable): ids: set = set() objs: list = [] add_ids = ids.update diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 657e820c..6838d03a 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,9 +1,11 @@ #!/usr/bin/python +from __future__ import annotations + import gc import unittest from datetime import date, datetime, time, timedelta, timezone -from typing import Any, Mapping +from typing import Any, ClassVar from uuid import UUID as Uuid # noqa: N811 import pgdb @@ -26,7 +28,7 @@ class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () - connect_kw_args: Mapping[str, Any] = { + connect_kw_args: ClassVar[dict[str, Any]] = { 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", 'user': dbuser, 'password': dbpasswd} @@ -159,8 +161,10 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): def row_factory(self, row): + description = self.description + assert isinstance(description, list) return {f'column {desc[0]}': value - for desc, value in zip(self.description, row)} + for desc, value in zip(description, row)} con = self._connect() con.cursor_type = TestCursor @@ -186,7 +190,9 @@ def test_build_row_factory(self): class TestCursor(pgdb.Cursor): def build_row_factory(self): - keys = [desc[0] for desc in self.description] + description = self.description + assert isinstance(description, list) + keys = [desc[0] for desc in description] return lambda row: { key: value for key, value in zip(keys, row)} @@ -566,19 +572,37 @@ def test_float(self): inval = -inf elif inval in ('nan', 'NaN'): inval = nan - if isinf(inval): + if isinf(inval): # type: ignore self.assertTrue(isinf(outval)) - if inval < 0: + if inval < 0: # type: ignore self.assertTrue(outval < 0) else: self.assertTrue(outval > 0) - elif isnan(inval): + elif isnan(inval): # type: ignore self.assertTrue(isnan(outval)) else: self.assertEqual(inval, outval) def test_datetime(self): dt = datetime(2011, 7, 17, 15, 47, 42, 317509) + values = [dt.date(), dt.time(), dt, dt.time(), dt] + assert isinstance(values[3], time) + values[3] = values[3].replace(tzinfo=timezone.utc) + assert isinstance(values[4], datetime) + values[4] = values[4].replace(tzinfo=timezone.utc) + d = (dt.year, dt.month, dt.day) + t = (dt.hour, dt.minute, dt.second, dt.microsecond) + z = (timezone.utc,) + inputs = [ + # input as objects + values, + # input as text + [v.isoformat() for v in values], # type: ignore + # input using type helpers + [pgdb.Date(*d),
pgdb.Time(*t), + pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), + pgdb.Timestamp(*(d + t + z))] + ] table = self.table_prefix + 'booze' con = self._connect() try: @@ -587,26 +611,11 @@ def test_datetime(self): cur.execute(f"create table {table} (" "d date, t time, ts timestamp," "tz timetz, tsz timestamptz)") - for n in range(3): - values = [dt.date(), dt.time(), dt, dt.time(), dt] - values[3] = values[3].replace(tzinfo=timezone.utc) - values[4] = values[4].replace(tzinfo=timezone.utc) - if n == 0: # input as objects - params = values - if n == 1: # input as text - params = [v.isoformat() for v in values] # as text - elif n == 2: # input using type helpers - d = (dt.year, dt.month, dt.day) - t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (timezone.utc,) - params = [pgdb.Date(*d), pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + for params in inputs: for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', 'sql, mdy', 'sql, dmy', 'german'): cur.execute(f"set datestyle to {datestyle}") - if n != 1: - # noinspection PyUnboundLocalVariable + if not isinstance(params[0], str): cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) @@ -615,11 +624,13 @@ def test_datetime(self): " values (%s,%s,%s,%s,%s)", params) cur.execute(f"select * from {table}") d = cur.description + assert isinstance(d, list) for i in range(5): - self.assertEqual(d[i].type_code, pgdb.DATETIME) - self.assertNotEqual(d[i].type_code, pgdb.STRING) - self.assertNotEqual(d[i].type_code, pgdb.ARRAY) - self.assertNotEqual(d[i].type_code, pgdb.RECORD) + tc = d[i].type_code + self.assertEqual(tc, pgdb.DATETIME) + self.assertNotEqual(tc, pgdb.STRING) + self.assertNotEqual(tc, pgdb.ARRAY) + self.assertNotEqual(tc, pgdb.RECORD) self.assertEqual(d[0].type_code, pgdb.DATE) self.assertEqual(d[1].type_code, pgdb.TIME) self.assertEqual(d[2].type_code, pgdb.TIMESTAMP) @@ -633,20 +644,20 @@ def test_datetime(self): def test_interval(self): td = datetime(2011, 7, 17, 15, 47, 42, 317509) - datetime(1970, 1, 1) + inputs = [ + # input as objects + td, + # input as text + f'{td.days} days {td.seconds} seconds' + f' {td.microseconds} microseconds', + # input using type helpers + pgdb.Interval(td.days, 0, 0, td.seconds, td.microseconds)] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() cur.execute(f"create table {table} (i interval)") - for n in range(3): - if n == 0: # input as objects - param = td - if n == 1: # input as text - param = (f'{td.days} days {td.seconds} seconds' - f' {td.microseconds} microseconds') - elif n == 2: # input using type helpers - param = pgdb.Interval( - td.days, 0, 0, td.seconds, td.microseconds) + for param in inputs: for intervalstyle in ('sql_standard ', 'postgres', 'postgres_verbose', 'iso_8601'): cur.execute(f"set intervalstyle to {intervalstyle}") @@ -705,7 +716,7 @@ def test_uuid(self): self.assertEqual(result, d) def test_insert_array(self): - values = [ + values: list[tuple[Any, Any]] = [ (None, None), ([], []), ([None], [[None], ['null']]), ([1, 2, 3], [['a', 'b'], ['c', 'd']]), ([20000, 25000, 25000, 30000], @@ -819,15 +830,15 @@ def test_select_record(self): def test_custom_type(self): values = [3, 5, 65] - values = list(map(PgBitString, values)) + values = list(map(PgBitString, values)) # type: ignore table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() - params = enumerate(values) # params have __pg_repr__ method + seq_params = 
enumerate(values) # params have __pg_repr__ method cur.execute( f'create table "{table}" (n smallint, b bit varying(7))') - cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.executemany(f"insert into {table} values (%s,%s)", seq_params) cur.execute(f"select * from {table}") rows = cur.fetchall() finally: @@ -850,20 +861,29 @@ def test_set_decimal_type(self): try: cur = con.cursor() # change decimal type globally to int - int_type = lambda v: int(float(v)) # noqa: E731 - self.assertTrue(pgdb.decimal_type(int_type) is int_type) + + class CustomDecimal(str): + + def __init__(self, value: Any) -> None: + self.value = value + + def __str__(self) -> str: + return str(self.value).replace('.', ',') + + self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] - self.assertTrue(isinstance(value, int)) - self.assertEqual(value, 4) + self.assertTrue(isinstance(value, CustomDecimal)) + self.assertEqual(str(value), '4,25') # change decimal type again to float self.assertTrue(pgdb.decimal_type(float) is float) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # the connection still uses the old setting - self.assertTrue(isinstance(value, int)) + self.assertTrue(isinstance(value, str)) + self.assertEqual(str(value), '4,25') # bust the cache for type functions for the connection con.type_cache.reset_typecast() cur.execute('select 4.25') @@ -1352,8 +1372,8 @@ def test_set_row_factory_size(self): info.hits, 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): - ids = set() - objs = [] + ids: set = set() + objs: list = [] add_ids = ids.update gc.collect() objs[:] = gc.get_objects() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 09211718..02810ba6 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -14,7 +14,7 @@ import unittest from collections.abc import Iterable from contextlib import suppress -from typing import Sequence +from typing import ClassVar import pgdb # the module under test @@ -101,6 +101,11 @@ class TestCopy(unittest.TestCase): cls_set_up = False + data: ClassVar[list[tuple[int, str]]] = [ + (1935, 'Luciano Pavarotti'), + (1941, 'Plácido Domingo'), + (1946, 'José Carreras')] + @staticmethod def connect(): host = f"{dbhost or ''}:{dbport or -1}" @@ -122,8 +127,9 @@ def setUpClass(cls): cur.execute("set client_encoding=utf8") cur.execute("select 'Plácido and José'").fetchone() except (pgdb.DataError, pgdb.NotSupportedError): - cls.data[1] = (1941, 'Plaacido Domingo') - cls.data[2] = (1946, 'Josee Carreras') + cls.data[1:3] = [ + (1941, 'Plaacido Domingo'), + (1946, 'Josee Carreras')] cls.can_encode = False cur.close() con.close() @@ -152,11 +158,6 @@ def tearDown(self): with suppress(Exception): self.con.close() - data: Sequence[tuple[int, str]] = [ - (1935, 'Luciano Pavarotti'), - (1941, 'Plácido Domingo'), - (1946, 'José Carreras')] - can_encode = True @property @@ -405,9 +406,9 @@ def test_generator(self): self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, str) - self.assertEqual(rows, self.data_text) + text = ''.join(rows) + self.assertIsInstance(text, str) + self.assertEqual(text, self.data_text) self.check_rowcount() def test_generator_with_schema_name(self): @@ -419,9 +420,9 @@ def test_generator_bytes(self): 
self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = b''.join(rows) - self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode()) + byte_text = b''.join(rows) + self.assertIsInstance(byte_text, bytes) + self.assertEqual(byte_text, self.data_text.encode()) def test_rowcount_increment(self): ret = self.copy_to() @@ -477,9 +478,9 @@ def test_csv(self): self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, str) - self.assertEqual(rows, self.data_csv) + csv = ''.join(rows) + self.assertIsInstance(csv, str) + self.assertEqual(csv, self.data_csv) self.check_rowcount(3) def test_csv_with_sep(self): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index c28fbefc..c09d13b8 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -139,7 +139,7 @@ def test_all_steps(self): cursor.executemany("insert into fruits (name) values (%s)", parameters) con.commit() cursor.execute('select * from fruits where id=1') - r = cursor.fetchone() + r: Any = cursor.fetchone() self.assertIsInstance(r, tuple) self.assertEqual(len(r), 2) r = str(r) From 18bb347aeb5dd618b7c8877d090adc086ffa8e5f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 08:47:02 +0200 Subject: [PATCH 151/194] Do not use OrderedDict anymore --- docs/conf.py | 2 +- docs/contents/pg/db_wrapper.rst | 9 +++---- docs/contents/pgdb/cursor.rst | 4 +-- docs/contents/tutorial.rst | 9 ++----- pg.py | 44 ++++++++++++++++----------------- tox.ini | 4 +-- 6 files changed, 33 insertions(+), 39 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0f95ab1b..9dd604f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,7 +41,7 @@ 'list', 'object', 'set', 'str', 'tuple', 'False', 'True', 'None', 'namedtuple', 'namedtuples', - 'OrderedDict', 'decimal.Decimal', + 'decimal.Decimal', 'bytes/str', 'list of namedtuples', 'tuple of callables', 'first field', 'type of first field', 'Notice', 'DATETIME'), diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index ea4f71c1..1dbd18ef 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -823,7 +823,7 @@ has only one column anyway. :param int offset: number of rows to be skipped (the OFFSET clause) :param bool scalar: whether only the first column shall be returned :returns: the content of the table as a list - :rtype: dict or OrderedDict + :rtype: dict :raises TypeError: the table name has not been specified :raises KeyError: keyname(s) are invalid or not part of the result :raises pg.ProgrammingError: no keyname(s) and table has no primary key @@ -837,10 +837,9 @@ The rows will be also named tuples unless the *scalar* option has been set to *True*. With the optional parameter *keyname* you can specify a different set of columns to be used as the keys of the dictionary. -If the Python version supports it, the dictionary will be an *OrderedDict* -using the order specified with the *order* parameter or the key column(s) -if not specified. You can set *order* to *False* if you don't care about the -ordering. In this case the returned dictionary will be an ordinary one. +The dictionary will be ordered using the order specified with the *order* +parameter or the key column(s) if not specified. You can set *order* to +*False* if you don't care about the ordering. .. 
versionadded:: 5.0 diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index e1ed8b0f..72473057 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -340,8 +340,8 @@ be used for all result sets. If you overwrite this method, the method will be ignored. Note that named tuples are very efficient and can be easily converted to -dicts (even OrderedDicts) by calling ``row._asdict()``. If you still want -to return rows as dicts, you can create a custom cursor class like this:: +dicts by calling ``row._asdict()``. If you still want to return rows as dicts, +you can create a custom cursor class like this:: class DictCursor(pgdb.Cursor): diff --git a/docs/contents/tutorial.rst b/docs/contents/tutorial.rst index 15577ad3..79273c7c 100644 --- a/docs/contents/tutorial.rst +++ b/docs/contents/tutorial.rst @@ -117,13 +117,8 @@ Using the method :meth:`DB.get_as_dict`, you can easily import the whole table into a Python dictionary mapping the primary key *id* to the *name*:: >>> db.get_as_dict('fruits', scalar=True) - OrderedDict([(1, 'apple'), - (2, 'banana'), - (3, 'cherimaya'), - (4, 'durian'), - (5, 'eggfruit'), - (6, 'fig'), - (7, 'grapefruit')]) + {1: 'apple', 2: 'banana', 3: 'cherimaya', 4: 'durian', 5: 'eggfruit', + 6: 'fig', 7: 'grapefruit', 8: 'apple', 9: 'banana'} To change a single row in the database, you can use the :meth:`DB.update` method. For instance, if you want to capitalize the name 'banana':: diff --git a/pg.py b/pg.py index 45f8ae46..75c0b32c 100644 --- a/pg.py +++ b/pg.py @@ -2638,12 +2638,13 @@ def truncate(self, table: str | list[str] | tuple[str, ...] | self._do_debug(cmd) return self._valid_db.query(cmd) - def get_as_list(self, table: str, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> list: + def get_as_list( + self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: """Get a table as a list. This gets a convenient representation of the table as a list @@ -2686,13 +2687,13 @@ def get_as_list(self, table: str, if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) cmd_parts.extend(['WHERE', where]) - if order is None: + if order is None or order is True: try: order = self.pkeys(table) except (KeyError, ProgrammingError): with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) - if order: + if order and not isinstance(order, bool): if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) cmd_parts.extend(['ORDER BY', order]) @@ -2708,13 +2709,14 @@ def get_as_list(self, table: str, res = [row[0] for row in res] return res - def get_as_dict(self, table: str, - keyname: str | list[str] | tuple[str, ...] | None = None, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> dict: + def get_as_dict( + self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] 
| None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: """Get a table as a dictionary. This method is similar to get_as_list(), but returns the table @@ -2728,11 +2730,9 @@ def get_as_dict(self, table: str, set of columns to be used as the keys of the dictionary. It must be set as a string, list or a tuple. - If the Python version supports it, the dictionary will be an - dict using the order specified with the 'order' parameter - or the key column(s) if not specified. You can set 'order' to False - if you don't care about the ordering. In this case the returned - dictionary will be an ordinary one. + The dictionary will be ordered using the order specified with the + 'order' parameter or the key column(s) if not specified. You can + set 'order' to False if you don't care about the ordering. """ if not table: raise TypeError('The table name is missing') @@ -2759,9 +2759,9 @@ def get_as_dict(self, table: str, if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) cmd_parts.extend(['WHERE', where]) - if order is None: + if order is None or order is True: order = keyname - if order: + if order and not isinstance(order, bool): if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) cmd_parts.extend(['ORDER BY', order]) diff --git a/tox.ini b/tox.ini index 7e52747d..322a3f32 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,cformat,docs +envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 @@ -13,7 +13,7 @@ commands = basepython = python3.11 deps = mypy>=1.5.1 commands = - mypy setup.py pg.py pgdb.py tests + mypy pg.py pgdb.py tests [testenv:cformat] basepython = python3.11 From 2ccc937ef018dac9c7f4ebd2c2ae2a81b808c254 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 12:51:06 +0200 Subject: [PATCH 152/194] Add back setup keywords in setup.py This brings some duplication with pyproject.toml but it avoids missing metadata and Python modules when using legacy install. Unfortunately, due to the C extension, we cannot completely get rid of setup.py yet. --- pg.py | 2 +- setup.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/pg.py b/pg.py index 75c0b32c..c7a34e1b 100644 --- a/pg.py +++ b/pg.py @@ -164,7 +164,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Query', + 'Connection', 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', diff --git a/setup.py b/setup.py index c20c9607..3ad3a906 100755 --- a/setup.py +++ b/setup.py @@ -25,11 +25,15 @@ raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") +with open('README.rst') as f: + long_description = f.read() + # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". 
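Returning to the ``get_as_list``/``get_as_dict`` ordering change a few hunks above ('order' may now also be a bool), a usage sketch; the connection parameters and the ``fruits`` table are assumptions::

    import pg

    db = pg.DB(dbname='testdb')                     # assumed connection
    db.get_as_dict('fruits', scalar=True)           # ordered by primary key
    db.get_as_dict('fruits', order='name')          # explicit ORDER BY
    db.get_as_dict('fruits', order=False)           # no ORDER BY clause
    db.get_as_list('fruits', what='name', limit=3)  # plain list instead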
+py_modules = ['pg', 'pgdb'] c_sources = ['pgmodule.c'] def pg_config(s): @@ -125,6 +129,40 @@ def finalize_options(self): setup( name="PyGreSQL", version=version, + description="Python PostgreSQL Interfaces", + long_description=long_description, + long_description_content_type='text/x-rst', + keywords="pygresql postgresql database api dbapi", + author="D'Arcy J. M. Cain", + author_email="darcy@PyGreSQL.org", + url="https://pygresql.github.io/", + download_url="https://pygresql.github.io/contents/download/", + project_urls={ + "Documentation": "https://pygresql.github.io/contents/", + "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", + "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", + "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, + classifiers=[ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: PostgreSQL License", + "Operating System :: OS Independent", + "Programming Language :: C", + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development :: Libraries :: Python Modules"], + license="PostgreSQL", + py_modules=py_modules, + test_suite='tests.discover', + zip_safe=False, ext_modules=[Extension( '_pg', c_sources, include_dirs=include_dirs, library_dirs=library_dirs, From 4d9103e61d9d96e7e99902307014846c390dace9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 23:42:40 +0200 Subject: [PATCH 153/194] Make inline type hints available and add stub file For this to work it was necessary to convert the two single modules to packages that can hold the pg.typed hint and stub file. The advantage of the stub file is that we now have proper type hints also for the objects imported from the C extension module. 
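With the ``py.typed`` markers and the ``_pg.pyi`` stub added by this commit, static type checkers can follow calls into the C extension. A small sketch of what that buys, with assumed connection parameters::

    import pg

    db = pg.DB(dbname='testdb')          # assumed database
    q = db.query("select 1 as answer")   # checkers now see a pg.Query here
    row = q.onenamed()                   # ...and a named tuple here
    db.close()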
--- .github/workflows/lint.yml | 2 +- MANIFEST.in | 16 +- docs/contents/install.rst | 9 +- docs/contents/pg/connection.rst | 4 +- docs/contents/pg/large_objects.rst | 4 +- docs/contents/pg/module.rst | 10 +- docs/download/files.rst | 29 +- pgconn.c => ext/pgconn.c | 0 pginternal.c => ext/pginternal.c | 0 pglarge.c => ext/pglarge.c | 0 pgmodule.c => ext/pgmodule.c | 2 +- pgnotice.c => ext/pgnotice.c | 0 pgquery.c => ext/pgquery.c | 0 pgsource.c => ext/pgsource.c | 4 +- pgtypes.h => ext/pgtypes.h | 0 pg.py => pg/__init__.py | 110 ++--- pg/_pg.pyi | 635 +++++++++++++++++++++++++++++ pg/py.typed | 4 + pgdb.py => pgdb/__init__.py | 88 ++-- pgdb/py.typed | 1 + pyproject.toml | 6 +- setup.py | 46 +-- tests/config.py | 5 +- tests/test_classic_connection.py | 78 ++-- tests/test_classic_dbwrapper.py | 36 +- tests/test_classic_functions.py | 48 +-- tests/test_dbapi20.py | 2 +- tox.ini | 6 +- 28 files changed, 888 insertions(+), 257 deletions(-) rename pgconn.c => ext/pgconn.c (100%) rename pginternal.c => ext/pginternal.c (100%) rename pglarge.c => ext/pglarge.c (100%) rename pgmodule.c => ext/pgmodule.c (99%) rename pgnotice.c => ext/pgnotice.c (100%) rename pgquery.c => ext/pgquery.c (100%) rename pgsource.c => ext/pgsource.c (99%) rename pgtypes.h => ext/pgtypes.h (100%) rename pg.py => pg/__init__.py (97%) create mode 100644 pg/_pg.pyi create mode 100644 pg/py.typed rename pgdb.py => pgdb/__init__.py (97%) create mode 100644 pgdb/py.typed diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 40f5299e..dad89096 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,5 +22,5 @@ jobs: with: python-version: 3.11 - name: Run quality checks - run: tox -e ruff,docs + run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/MANIFEST.in b/MANIFEST.in index e6e9e5a9..4ff1c2b6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,14 +1,20 @@ -include *.c -include *.h -include *.py +include setup.py + +recursive-include pg pgdb tests *.py + +include pg/*.pyi +include pg/py.typed +include pgdb/py.typed + +include ext/*.c +include ext/*.h include README.rst include LICENSE.txt include tox.ini - -recursive-include tests *.py +include pyproject.toml include docs/Makefile include docs/make.bat diff --git a/docs/contents/install.rst b/docs/contents/install.rst index fd4f99b5..f447abc3 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -16,11 +16,10 @@ On Windows, you also need to make sure that the directory that contains The current version of PyGreSQL has been tested with Python versions 3.7 to 3.11, and PostgreSQL versions 10 to 15. -PyGreSQL will be installed as three modules, a shared library called -``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure -Python wrapper modules called ``pg.py`` and ``pgdb.py``. -All three files will be installed directly into the Python site-packages -directory. To uninstall PyGreSQL, simply remove these three files. +PyGreSQL will be installed as two packages named ``pg`` (for the classic +interface) and ``pgdb`` (for the DB API 2 compliant interface). The former +also contains a shared library called ``_pg.so`` (on Linux) or a DLL called +``_pg.pyd`` (on Windows) and a stub file ``_pg.pyi`` for this library. 
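A quick smoke test of the new package layout, run in any Python shell after installing::

    import pg
    import pgdb

    print(pg.version)  # also exposed as pg.__version__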
Installing with Pip diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index b175a2a0..e4a08591 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -616,7 +616,7 @@ getline -- get a line from server socket Get a line from server socket - :returns: the line read + :returns: the line read :rtype: str :raises TypeError: invalid connection :raises TypeError: too many parameters @@ -666,7 +666,7 @@ getlo -- build a large object from given oid :param int oid: OID of the existing large object :returns: object handling the PostgreSQL large object :rtype: :class:`LargeObject` - :raises TypeError: invalid connection, bad parameter type, or too many parameters + :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises ValueError: bad OID value (0 is invalid_oid) This method allows reusing a previously created large object through the diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index a1d9818d..037b2128 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -75,9 +75,9 @@ current position. .. method:: LargeObject.write(string) - Read data to large object + Write data to large object - :param bytes string: string buffer to be written + :param bytes data: buffer of bytes to be written :rtype: None :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: object is not opened, or write error diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 2dc26d5f..acf75f93 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -349,8 +349,7 @@ get/set_decimal -- decimal type to be used for numeric values :rtype: class This function returns the Python class that is used by PyGreSQL to hold -PostgreSQL numeric values. The default class is :class:`decimal.Decimal` -if available, otherwise the :class:`float` type is used. +PostgreSQL numeric values. The default class is :class:`decimal.Decimal`. .. function:: set_decimal(cls) @@ -360,8 +359,7 @@ get/set_decimal -- decimal type to be used for numeric values This function can be used to specify the Python class that shall be used by PyGreSQL to hold PostgreSQL numeric values. -The default class is :class:`decimal.Decimal` if available, -otherwise the :class:`float` type is used. +The default class is :class:`decimal.Decimal`. get/set_decimal_point -- decimal mark used for monetary values -------------------------------------------------------------- @@ -639,7 +637,7 @@ are not supported by default in PostgreSQL. :param str string: the string with the text representation of the array :param cast: a typecast function for the elements of the array :type cast: callable or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements :type str: byte string with a single character :returns: a list representing the PostgreSQL array in Python :rtype: list @@ -667,7 +665,7 @@ then a comma will be used by default. :param str string: the string with the text representation of the record :param cast: typecast function(s) for the elements of the record :type cast: callable, list or tuple of callables, or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements :type str: byte string with a single character :returns: a tuple representing the PostgreSQL record in Python :rtype: tuple diff --git a/docs/download/files.rst b/docs/download/files.rst index ec581bf0..f5e7a523 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,26 +3,13 @@ Distribution files ============== = -pgmodule.c the main source file for the C extension module (_pg) -pgconn.c the connection object -pginternal.c internal functions -pglarge.c large object support -pgnotice.c the notice object -pgquery.c the query object -pgsource.c the source object +pg/ the "classic" PyGreSQL module -pgtypes.h PostgreSQL type definitions +pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -pg.py the "classic" PyGreSQL module -pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL +ext/ the source files for the C extension -setup.py the Python setup script - - To install PyGreSQL, you can run "python setup.py install". - -setup.cfg the Python setup configuration - -docs/ documentation directory +docs/ the documentation directory The documentation has been created with Sphinx. All text files are in ReST format; a HTML version of @@ -30,4 +17,12 @@ docs/ documentation directory tests/ a suite of unit tests for PyGreSQL +pyproject.toml contains project metadata and the build system requirements + +setup.py the Python setup script used for building the C extension + +LICENSE.txt contains the license information for PyGreSQL + +README.rst a summary of the PyGreSQL project + ============== = diff --git a/pgconn.c b/ext/pgconn.c similarity index 100% rename from pgconn.c rename to ext/pgconn.c diff --git a/pginternal.c b/ext/pginternal.c similarity index 100% rename from pginternal.c rename to ext/pginternal.c diff --git a/pglarge.c b/ext/pglarge.c similarity index 100% rename from pglarge.c rename to ext/pglarge.c diff --git a/pgmodule.c b/ext/pgmodule.c similarity index 99% rename from pgmodule.c rename to ext/pgmodule.c index 64e769f6..546c5cc5 100644 --- a/pgmodule.c +++ b/ext/pgmodule.c @@ -180,7 +180,7 @@ typedef struct { /* Connect to a database. */ static char pg_connect__doc__[] = - "connect(dbname, host, port, opt, user, passwd, wait) -- connect to a " + "connect(dbname, host, port, opt, user, passwd, nowait) -- connect to a " "PostgreSQL database\n\n" "The connection uses the specified parameters (optional, keywords " "aware).\n"; diff --git a/pgnotice.c b/ext/pgnotice.c similarity index 100% rename from pgnotice.c rename to ext/pgnotice.c diff --git a/pgquery.c b/ext/pgquery.c similarity index 100% rename from pgquery.c rename to ext/pgquery.c diff --git a/pgsource.c b/ext/pgsource.c similarity index 99% rename from pgsource.c rename to ext/pgsource.c index 9bc6bb4a..42510b30 100644 --- a/pgsource.c +++ b/ext/pgsource.c @@ -119,8 +119,8 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object.
:param str string: the string with the text representation of the record :param cast: typecast function(s) for the elements of the record :type cast: callable, list or tuple of callables, or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements :type str: byte string with a single character :returns: a tuple representing the PostgreSQL record in Python :rtype: tuple diff --git a/docs/download/files.rst b/docs/download/files.rst index ec581bf0..f5e7a523 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,26 +3,13 @@ Distribution files ============== = -pgmodule.c the main source file for the C extension module (_pg) -pgconn.c the connection object -pginternal.c internal functions -pglarge.c large object support -pgnotice.c the notice object -pgquery.c the query object -pgsource.c the source object +pg/ the "classic" PyGreSQL module -pgtypes.h PostgreSQL type definitions +pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -pg.py the "classic" PyGreSQL module -pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL +ext/ the source files for the C extension -setup.py the Python setup script - - To install PyGreSQL, you can run "python setup.py install". - -setup.cfg the Python setup configuration - -docs/ documentation directory +docs/ the documentation directory The documentation has been created with Sphinx. All text files are in ReST format; a HTML version of @@ -30,4 +17,12 @@ docs/ documentation directory tests/ a suite of unit tests for PyGreSQL +pyproject.toml contains project metadata and the build system requirements + +setup.py the Python setup script used for building the C extension + +LICENSE.text contains the license information for PyGreSQL + +README.rst a summary of the PyGreSQL project + ============== = diff --git a/pgconn.c b/ext/pgconn.c similarity index 100% rename from pgconn.c rename to ext/pgconn.c diff --git a/pginternal.c b/ext/pginternal.c similarity index 100% rename from pginternal.c rename to ext/pginternal.c diff --git a/pglarge.c b/ext/pglarge.c similarity index 100% rename from pglarge.c rename to ext/pglarge.c diff --git a/pgmodule.c b/ext/pgmodule.c similarity index 99% rename from pgmodule.c rename to ext/pgmodule.c index 64e769f6..546c5cc5 100644 --- a/pgmodule.c +++ b/ext/pgmodule.c @@ -180,7 +180,7 @@ typedef struct { /* Connect to a database. */ static char pg_connect__doc__[] = - "connect(dbname, host, port, opt, user, passwd, wait) -- connect to a " + "connect(dbname, host, port, opt, user, passwd, nowait) -- connect to a " "PostgreSQL database\n\n" "The connection uses the specified parameters (optional, keywords " "aware).\n"; diff --git a/pgnotice.c b/ext/pgnotice.c similarity index 100% rename from pgnotice.c rename to ext/pgnotice.c diff --git a/pgquery.c b/ext/pgquery.c similarity index 100% rename from pgquery.c rename to ext/pgquery.c diff --git a/pgsource.c b/ext/pgsource.c similarity index 99% rename from pgsource.c rename to ext/pgsource.c index 9bc6bb4a..42510b30 100644 --- a/pgsource.c +++ b/ext/pgsource.c @@ -119,8 +119,8 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object. 
*/ static char source_close__doc__[] = - "close() -- close query object without deleting it\n\n" - "All instances of the query object can no longer be used after this " + "close() -- close source object without deleting it\n\n" + "All instances of the source object can no longer be used after this " "call.\n"; static PyObject * diff --git a/pgtypes.h b/ext/pgtypes.h similarity index 100% rename from pgtypes.h rename to ext/pgtypes.h diff --git a/pg.py b/pg/__init__.py similarity index 97% rename from pg.py rename to pg/__init__.py index c7a34e1b..0740db20 100644 --- a/pg.py +++ b/pg/__init__.py @@ -51,7 +51,7 @@ from uuid import UUID try: - from _pg import version + from ._pg import version except ImportError as e: # noqa: F841 import os libpq = 'libpq.' @@ -66,7 +66,7 @@ for path in paths: with add_dll_dir(os.path.abspath(path)): try: - from _pg import version # type: ignore + from ._pg import version except ImportError: pass else: @@ -85,13 +85,17 @@ del version # import objects from extension module -from _pg import ( +from ._pg import ( INV_READ, INV_WRITE, POLLING_FAILED, POLLING_OK, POLLING_READING, POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, SEEK_CUR, SEEK_END, SEEK_SET, @@ -167,6 +171,7 @@ 'Connection', 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', + 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', 'TRANS_INTRANS', 'TRANS_UNKNOWN', @@ -186,6 +191,8 @@ # Auxiliary classes and functions that are independent of a DB connection: +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + def get_args(func: Callable) -> list: return list(signature(func).parameters) @@ -1188,7 +1195,7 @@ class DbType(str): category: str delim: str relid: int - + _get_attnames: Callable[[DbType], AttrDict] @property @@ -1336,14 +1343,14 @@ def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: yield dict(zip(fields, r)) -def _namediter(q: Query) -> Generator[NamedTuple, None, None]: +def _namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: """Get query result as an iterator of named tuples.""" row = _row_factory(q.listfields()) for r in q: yield row(r) -def _namednext(q: Query) -> NamedTuple: +def _namednext(q: Query) -> SomeNamedTuple: """Get next row from query result as a named tuple.""" return _row_factory(q.listfields())(next(q)) @@ -1378,23 +1385,29 @@ def __iter__(self) -> Iterator[Any]: # Error messages -E = TypeVar('E', bound=DatabaseError) +E = TypeVar('E', bound=Error) -def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: - """Return DatabaseError with empty sqlstate attribute.""" +def _error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" error = cls(msg) - error.sqlstate = None + if isinstance(error, DatabaseError): + error.sqlstate = None return error +def _db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return _error(msg, DatabaseError) + + def _int_error(msg: str) -> InternalError: """Return InternalError.""" - return _db_error(msg, InternalError) + return _error(msg, InternalError) def _prg_error(msg: str) -> ProgrammingError: """Return ProgrammingError.""" - return _db_error(msg, ProgrammingError) + return _error(msg, ProgrammingError) # Initialize the C module @@ -1468,7 +1481,7 @@ def unlisten(self) -> None: self.listening = False def notify(self, db: DB | None = None, stop: bool = False, - 
payload: str | None = None) -> None: + payload: str | None = None) -> Query | None: """Generate a notification. Optionally, you can pass a payload with the notification. @@ -1480,16 +1493,17 @@ def notify(self, db: DB | None = None, stop: bool = False, must pass a different database connection since PyGreSQL database connections are not thread-safe. """ - if self.listening: + if not self.listening: + return None + if not db: + db = self.db if not db: - db = self.db - if not db: - return - event = self.stop_event if stop else self.event - cmd = f'notify "{event}"' - if payload: - cmd += f", '{payload}'" - return db.query(cmd) + return None + event = self.stop_event if stop else self.event + cmd = f'notify "{event}"' + if payload: + cmd += f", '{payload}'" + return db.query(cmd) def __call__(self) -> None: """Invoke the notification handler. @@ -1545,6 +1559,7 @@ class DB: """Wrapper class for the _pg connection type.""" db: Connection | None = None # invalid fallback for underlying connection + _db_args: Any # either the connection args or the underlying connection def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. @@ -1730,7 +1745,7 @@ def reset(self) -> None: All derived queries and large objects derived from this connection will not be usable after this call. - """ + """ self._valid_db.reset() def reopen(self) -> None: @@ -1741,7 +1756,8 @@ def reopen(self) -> None: """ # There is no such shared library function. if self._closeable: - db = connect(*self._db_args[0], **self._db_args[1]) + args, kw = self._db_args + db = connect(*args, **kw) if self.db: self.db.set_cast_hook(None) self.db.close() @@ -1750,7 +1766,7 @@ def reopen(self) -> None: else: self.db = self._db_args - def begin(self, mode: str | None = None) -> None: + def begin(self, mode: str | None = None) -> Query: """Begin a transaction.""" qstr = 'BEGIN' if mode: @@ -1759,13 +1775,13 @@ def begin(self, mode: str | None = None) -> None: start = begin - def commit(self) -> None: + def commit(self) -> Query: """Commit the current transaction.""" return self.query('COMMIT') end = commit - def rollback(self, name: str | None = None) -> None: + def rollback(self, name: str | None = None) -> Query: """Roll back the current transaction.""" qstr = 'ROLLBACK' if name: @@ -1774,11 +1790,11 @@ def rollback(self, name: str | None = None) -> None: abort = rollback - def savepoint(self, name: str) -> None: + def savepoint(self, name: str) -> Query: """Define a new savepoint within the current transaction.""" return self.query('SAVEPOINT ' + name) - def release(self, name: str) -> None: + def release(self, name: str) -> Query: """Destroy a previously defined savepoint.""" return self.query('RELEASE ' + name) @@ -1983,7 +1999,7 @@ def query_prepared(self, name: str, *args: Any) -> Query: self._do_debug('EXECUTE', name) return db.query_prepared(name) - def prepare(self, name: str, command: str) -> Query: + def prepare(self, name: str, command: str) -> None: """Create a prepared SQL statement. This creates a prepared statement for the given command with the @@ -1999,7 +2015,7 @@ def prepare(self, name: str, command: str) -> Query: if name is None: name = '' self._do_debug('prepare', name, command) - return self._valid_db.prepare(name, command) + self._valid_db.prepare(name, command) def describe_prepared(self, name: str | None = None) -> Query: """Describe a prepared SQL statement.
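A short sketch of the ``prepare()``/``query_prepared()`` pair annotated above; the connection parameters and the ``fruits`` table are assumptions::

    import pg

    db = pg.DB(dbname='testdb')  # assumed connection
    db.prepare('fruit_by_id', "select name from fruits where id = $1")
    q = db.query_prepared('fruit_by_id', 2)  # executes it, returns a Query
    print(q.onescalar())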
@@ -2057,17 +2073,17 @@ def pkey(self, table: str, composite: bool = False, flush: bool = False " {}::pg_catalog.regclass" " AND i.indisprimary ORDER BY a.attnum").format( _quote_if_unqualified('$1', table)) - pkey = self._valid_db.query(cmd, (table,)).getresult() - if not pkey: + res = self._valid_db.query(cmd, (table,)).getresult() + if not res: raise KeyError(f'Table {table} has no primary key') from e # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table - if len(pkey) > 1: - indkey = pkey[0][2] + if len(res) > 1: + indkey = res[0][2] pkey = tuple(row[0] for row in sorted( - pkey, key=lambda row: indkey.index(row[1]))) + res, key=lambda row: indkey.index(row[1]))) else: - pkey = pkey[0][0] + pkey = res[0][0] pkeys[table] = pkey # cache it if composite and not isinstance(pkey, tuple): pkey = (pkey,) @@ -2075,7 +2091,7 @@ def pkey(self, table: str, composite: bool = False, flush: bool = False def pkeys(self, table: str) -> tuple[str, ...]: """Get the primary key of a table as a tuple. - + Same as pkey() with 'composite' set to True. """ return self.pkey(table, True) # type: ignore @@ -2146,9 +2162,9 @@ def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" cmd = self._query_attnames.format( _quote_if_unqualified('$1', table), cmd) - names = self._valid_db.query(cmd, (table,)).getresult() + res = self._valid_db.query(cmd, (table,)).getresult() types = self.dbtypes - names = AttrDict((name[0], types.add(*name[1:])) for name in names) + names = AttrDict((name[0], types.add(*name[1:])) for name in res) attnames[table] = names # cache it return names @@ -2172,8 +2188,8 @@ def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: cmd = f"{cmd} AND {self._query_generated}" cmd = self._query_attnames.format( _quote_if_unqualified('$1', table), cmd) - names = self._valid_db.query(cmd, (table,)).getresult() - names = frozenset(name[0] for name in names) + res = self._valid_db.query(cmd, (table,)).getresult() + names = frozenset(name[0] for name in res) generated[table] = names # cache it return names @@ -2578,7 +2594,7 @@ def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 self._do_debug(cmd, params) res = self._valid_db.query(cmd, params) - return int(res) + return int(res) # type: ignore def truncate(self, table: str | list[str] | tuple[str, ...] | set[str] | frozenset[str], restart: bool = False, @@ -2660,7 +2676,7 @@ def get_as_list( The parameter 'where' can restrict the query to only return a subset of the table rows. It can be a string, list or a tuple of SQL expressions that all need to be fulfilled. - + The parameter 'order' specifies the ordering of the rows. It can also be a string, list or a tuple. 
If no ordering is specified, the result will be ordered by the primary key(s) or all columns if @@ -2806,7 +2822,7 @@ def get_row(row : tuple) -> tuple: if key_tuple: keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore if row_is_tuple: - fields = [f for f in fields if f not in keyset] + fields = tuple(f for f in fields if f not in keyset) rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore # noinspection PyArgumentList return dict(zip(keys, rows)) @@ -2824,6 +2840,6 @@ def notification_handler(self, event: str, callback: Callable, # if run as script, print some information if __name__ == '__main__': - print('PyGreSQL version' + version) - print('') + print('PyGreSQL version', version) + print() print(__doc__) diff --git a/pg/_pg.pyi b/pg/_pg.pyi new file mode 100644 index 00000000..70f6e37e --- /dev/null +++ b/pg/_pg.pyi @@ -0,0 +1,635 @@ +"""Type hints for the PyGreSQL C extension.""" + +from __future__ import annotations + +from typing import Any, Callable, Iterable, Sequence, TypeVar + +AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +version: str +__version__: str + +RESULT_EMPTY: int +RESULT_DML: int +RESULT_DDL: int +RESULT_DQL: int + +TRANS_IDLE: int +TRANS_ACTIVE: int +TRANS_INTRANS: int +TRANS_INERROR: int +TRANS_UNKNOWN: int + +POLLING_OK: int +POLLING_FAILED: int +POLLING_READING: int +POLLING_WRITING: int + +INV_READ: int +INV_WRITE: int + +SEEK_SET: int +SEEK_CUR: int +SEEK_END: int + + +class Error(Exception): + """Exception that is the base class of all other error exceptions.""" + + +class Warning(Exception): # noqa: N818 + """Exception raised for important warnings.""" + + +class InterfaceError(Error): + """Exception raised for errors related to the database interface.""" + + +class DatabaseError(Error): + """Exception raised for errors that are related to the database.""" + + sqlstate: str | None + + +class InternalError(DatabaseError): + """Exception raised when the database encounters an internal error.""" + + +class OperationalError(DatabaseError): + """Exception raised for errors related to the operation of the database.""" + + +class ProgrammingError(DatabaseError): + """Exception raised for programming errors.""" + + +class IntegrityError(DatabaseError): + """Exception raised when the relational integrity is affected.""" + + +class DataError(DatabaseError): + """Exception raised for errors due to problems with the processed data.""" + + +class NotSupportedError(DatabaseError): + """Exception raised when a method or database API is not supported.""" + + +class InvalidResultError(DataError): + """Exception when a database operation produced an invalid result.""" + + +class NoResultError(InvalidResultError): + """Exception when a database operation did not produce any result.""" + + +class MultipleResultsError(InvalidResultError): + """Exception when a database operation produced multiple results.""" + + +class Source: + """Source object.""" + + arraysize: int + resulttype: int + ntuples: int + nfields: int + + def execute(self, sql: str) -> int | None: + """Execute a SQL statement.""" + ... + + def fetch(self, num: int) -> list[tuple]: + """Return the next num rows from the last result in a list.""" + ... + + def listinfo(self) -> tuple[tuple[int, str, int, int, int], ...]: + """Get information for all fields.""" + ... + + def oidstatus(self) -> int | None: + """Return oid of last inserted row (if available).""" + ... 
+ + def putdata(self, buffer: str | bytes | BaseException | None + ) -> int | None: + """Send data to server during copy from stdin.""" + ... + + def getdata(self, decode: bool | None = None) -> str | bytes | int: + """Receive data to server during copy to stdout.""" + ... + + def close(self) -> None: + """Close query object without deleting it.""" + ... + + +class LargeObject: + """Large object.""" + + oid: int + pgcnx: Connection + error: str + + def open(self, mode: int) -> None: + """Open a large object. + + The valid values for 'mode' parameter are defined as the module level + constants INV_READ and INV_WRITE. + """ + ... + + def close(self) -> None: + """Close a large object.""" + ... + + def read(self, size: int) -> bytes: + """Read data from large object.""" + ... + + def write(self, data: bytes) -> None: + """Write data to large object.""" + ... + + def seek(self, offset: int, whence: int) -> int: + """Change current position in large object. + + The valid values for the 'whence' parameter are defined as the + module level constants SEEK_SET, SEEK_CUR and SEEK_END. + """ + ... + + def unlink(self) -> None: + """Delete large object.""" + ... + + def size(self) -> int: + """Return the large object size.""" + ... + + def export(self, filename: str) -> None: + """Export a large object to a file.""" + ... + + +class Connection: + """Connection object. + + This object handles a connection to a PostgreSQL database. + It embeds and hides all the parameters that define this connection, + thus just leaving really significant parameters in function calls. + """ + + host: str + port: int + db: str + options: str + error: str + status: int + user : str + protocol_version: int + server_version: int + socket: int + backend_pid: int + ssl_in_use: bool + ssl_attributes: dict[str, str | None] + + def source(self) -> Source: + """Create a new source object for this connection.""" + ... + + def query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new query object for this connection. + + Note that if the command is something other than DQL, this method + can return an int, str or None instead of a Query. + """ + ... + + def send_query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new asynchronous query object for this connection.""" + ... + + def query_prepared(self, name: str, args: Sequence | None = None) -> Query: + """Execute a prepared statement.""" + ... + + def prepare(self, name: str, cmd: str) -> None: + """Create a prepared statement.""" + ... + + def describe_prepared(self, name: str) -> Query: + """Describe a prepared statement.""" + ... + + def poll(self) -> int: + """Complete an asynchronous connection and get its state.""" + ... + + def reset(self) -> None: + """Reset the connection.""" + ... + + def cancel(self) -> None: + """Abandon processing of current SQL command.""" + ... + + def close(self) -> None: + """Close the database connection.""" + ... + + def fileno(self) -> int: + """Get the socket used to connect to the database.""" + ... + + def get_cast_hook(self) -> Callable | None: + """Get the function that handles all external typecasting.""" + ... + + def set_cast_hook(self, hook: Callable | None) -> None: + """Set a function that will handle all external typecasting.""" + ... + + def get_notice_receiver(self) -> Callable | None: + """Get the current notice receiver.""" + ... + + def set_notice_receiver(self, receiver: Callable | None) -> None: + """Set a custom notice receiver.""" + ... 
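The notice receiver hooks stubbed above can be exercised like this; a minimal sketch, where the ``testdb`` database name is an assumption::

    import pg

    con = pg.connect('testdb')  # assumed database name
    con.set_notice_receiver(
        lambda notice: print('received:', notice.message))
    con.query("do $$ begin raise notice 'hello'; end $$")
    con.close()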
+
+    def getnotify(self) -> tuple[str, int, str] | None:
+        """Get the last notify from the server."""
+        ...
+
+    def inserttable(self, table: str, values: Sequence[list | tuple],
+                    columns: list[str] | tuple[str, ...] | None = None) -> int:
+        """Insert a Python iterable into a database table."""
+        ...
+
+    def transaction(self) -> int:
+        """Get the current in-transaction status of the server.
+
+        The status returned by this method can be TRANS_IDLE (currently idle),
+        TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a
+        valid transaction block), or TRANS_INERROR (idle, in a failed
+        transaction block). TRANS_UNKNOWN is reported if the connection is
+        bad. The status TRANS_ACTIVE is reported only when a query has been
+        sent to the server and not yet completed.
+        """
+        ...
+
+    def parameter(self, name: str) -> str | None:
+        """Look up a current parameter setting of the server."""
+        ...
+
+    def date_format(self) -> str:
+        """Look up the date format currently being used by the database."""
+        ...
+
+    def escape_literal(self, s: AnyStr) -> AnyStr:
+        """Escape a literal constant for use within SQL."""
+        ...
+
+    def escape_identifier(self, s: AnyStr) -> AnyStr:
+        """Escape an identifier for use within SQL."""
+        ...
+
+    def escape_string(self, s: AnyStr) -> AnyStr:
+        """Escape a string for use within SQL."""
+        ...
+
+    def escape_bytea(self, s: AnyStr) -> AnyStr:
+        """Escape binary data for use within SQL as type 'bytea'."""
+        ...
+
+    def putline(self, line: str) -> None:
+        """Write a line to the server socket."""
+        ...
+
+    def getline(self) -> str:
+        """Get a line from the server socket."""
+        ...
+
+    def endcopy(self) -> None:
+        """Synchronize client and server."""
+        ...
+
+    def set_non_blocking(self, nb: bool) -> None:
+        """Set the non-blocking mode of the connection."""
+        ...
+
+    def is_non_blocking(self) -> bool:
+        """Get the non-blocking mode of the connection."""
+        ...
+
+    def locreate(self, mode: int) -> LargeObject:
+        """Create a large object in the database.
+
+        The valid values for the 'mode' parameter are defined as the
+        module level constants INV_READ and INV_WRITE.
+        """
+        ...
+
+    def getlo(self, oid: int) -> LargeObject:
+        """Build a large object from the given oid."""
+        ...
+
+    def loimport(self, filename: str) -> LargeObject:
+        """Import a file to a large object."""
+        ...
+
+
+class Query:
+    """Query object.
+
+    The Query object returned by Connection.query and DB.query can be used
+    as an iterable returning rows as tuples. You can also directly access
+    row tuples using their index, and get the number of rows with the
+    len() function. The Query class also provides several methods
+    for accessing the results of the query.
+    """
+
+    def __len__(self) -> int:
+        ...
+
+    def __getitem__(self, key: int) -> object:
+        ...
+
+    def __iter__(self) -> Query:
+        ...
+
+    def __next__(self) -> tuple:
+        ...
+
+    def getresult(self) -> list[tuple]:
+        """Get query values as list of tuples."""
+        ...
+
+    def dictresult(self) -> list[dict[str, object]]:
+        """Get query values as list of dictionaries."""
+        ...
+
+    def dictiter(self) -> Iterable[dict[str, object]]:
+        """Get query values as iterable of dictionaries."""
+        ...
+
+    def namedresult(self) -> list[SomeNamedTuple]:
+        """Get query values as list of named tuples."""
+        ...
+
+    def namediter(self) -> Iterable[SomeNamedTuple]:
+        """Get query values as iterable of named tuples."""
+        ...
+
+    def one(self) -> tuple | None:
+        """Get one row from the result of a query as a tuple."""
+        ...
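
The Query docstring above describes several access patterns; here is a brief sketch of them, again with a placeholder connection:

    import pg

    cnx = pg.connect(dbname='test')  # placeholder database name
    q = cnx.query("select n, n * n as square from generate_series(1, 3) as n")
    print(len(q), q[0])    # row count and first row as a tuple
    for row in q:          # iterate over rows as tuples
        print(row)
    print(q.dictresult())  # rows as dictionaries
    print(q.namedresult()[0].square)  # rows as named tuples
    cnx.close()
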
+
+    def single(self) -> tuple:
+        """Get single row from the result of a query as a tuple."""
+        ...
+
+    def onedict(self) -> dict[str, object] | None:
+        """Get one row from the result of a query as a dictionary."""
+        ...
+
+    def singledict(self) -> dict[str, object]:
+        """Get single row from the result of a query as a dictionary."""
+        ...
+
+    def onenamed(self) -> SomeNamedTuple | None:
+        """Get one row from the result of a query as named tuple."""
+        ...
+
+    def singlenamed(self) -> SomeNamedTuple:
+        """Get single row from the result of a query as named tuple."""
+        ...
+
+    def scalarresult(self) -> list:
+        """Get first fields from query result as list of scalar values."""
+        ...
+
+    def scalariter(self) -> Iterable:
+        """Get first fields from query result as iterable of scalar values."""
+        ...
+
+    def onescalar(self) -> object | None:
+        """Get one row from the result of a query as scalar value."""
+        ...
+
+    def singlescalar(self) -> object:
+        """Get single row from the result of a query as scalar value."""
+        ...
+
+    def fieldname(self, num: int) -> str:
+        """Get field name from its number."""
+        ...
+
+    def fieldnum(self, name: str) -> int:
+        """Get field number from its name."""
+        ...
+
+    def listfields(self) -> tuple[str, ...]:
+        """List field names of query result."""
+        ...
+
+    def fieldinfo(self, column: int | str | None) -> tuple[str, int, int, int]:
+        """Get information on one or all fields of the query.
+
+        The four-tuples contain the following information:
+        The field name, the internal OID number of the field type,
+        the size in bytes of the column or a negative value if it is
+        of variable size, and a type-specific modifier value.
+        """
+        ...
+
+    def memsize(self) -> int:
+        """Return number of bytes allocated by query result."""
+        ...
+
+
+def connect(dbname: str | None = None,
+            host: str | None = None,
+            port: int | None = None,
+            opt: str | None = None,
+            user: str | None = None,
+            passwd: str | None = None,
+            nowait: int | None = None) -> Connection:
+    """Connect to a PostgreSQL database."""
+    ...
+
+
+def cast_array(s: str, cast: Callable | None = None,
+               delim: bytes | None = None) -> list:
+    """Cast a string representing a PostgreSQL array to a Python list."""
+    ...
+
+
+def cast_record(s: str,
+                cast: Callable | list[Callable | None] |
+                tuple[Callable | None, ...] | None = None,
+                delim: bytes | None = None) -> tuple:
+    """Cast a string representing a PostgreSQL record to a Python tuple."""
+    ...
+
+
+def cast_hstore(s: str) -> dict[str, str | None]:
+    """Cast a string as a hstore."""
+    ...
+
+
+def escape_bytea(s: AnyStr) -> AnyStr:
+    """Escape binary data for use within SQL as type 'bytea'."""
+    ...
+
+
+def unescape_bytea(s: AnyStr) -> bytes:
+    """Unescape 'bytea' data that has been retrieved as text."""
+    ...
+
+
+def escape_string(s: AnyStr) -> AnyStr:
+    """Escape a string for use within SQL."""
+    ...
+
+
+def get_pqlib_version() -> int:
+    """Get the version of libpq that is being used by PyGreSQL."""
+    ...
+
+
+def get_array() -> bool:
+    """Check whether arrays are returned as list objects."""
+    ...
+
+
+def set_array(on: bool) -> None:
+    """Set whether arrays are returned as list objects."""
+    ...
+
+
+def get_bool() -> bool:
+    """Check whether boolean values are returned as bool objects."""
+    ...
+
+
+def set_bool(on: bool | int) -> None:
+    """Set whether boolean values are returned as bool objects."""
+    ...
+
+
+def get_bytea_escaped() -> bool:
+    """Check whether 'bytea' values are returned as escaped strings."""
+    ...
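
A short sketch of the module-level get/set switches declared in this part of the stub; they control globally how values in new query results are typecast:

    from decimal import Decimal

    import pg

    pg.set_bool(True)        # return boolean columns as Python bool
    pg.set_array(True)       # return array columns as Python lists
    pg.set_decimal(Decimal)  # class used for numeric columns
    print(pg.get_bool(), pg.get_array(), pg.get_decimal())
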
+
+
+def set_bytea_escaped(on: bool | int) -> None:
+    """Set whether 'bytea' values are returned as escaped strings."""
+    ...
+
+
+def get_datestyle() -> str | None:
+    """Get the assumed date style for typecasting."""
+    ...
+
+
+def set_datestyle(datestyle: str | None) -> None:
+    """Set a fixed date style that shall be assumed when typecasting."""
+    ...
+
+
+def get_decimal() -> type:
+    """Get the decimal type to be used for numeric values."""
+    ...
+
+
+def set_decimal(cls: type) -> None:
+    """Set the decimal type to be used for numeric values."""
+    ...
+
+
+def get_decimal_point() -> str | None:
+    """Get the decimal mark used for monetary values."""
+    ...
+
+
+def set_decimal_point(mark: str | None) -> None:
+    """Specify which decimal mark is used for interpreting monetary values."""
+    ...
+
+
+def get_jsondecode() -> Callable[[str], object] | None:
+    """Get the function that deserializes JSON formatted strings."""
+    ...
+
+
+def set_jsondecode(decode: Callable[[str], object] | None) -> None:
+    """Set a function that will deserialize JSON formatted strings."""
+    ...
+
+
+def get_defbase() -> str | None:
+    """Get the default database name."""
+    ...
+
+
+def set_defbase(base: str | None) -> None:
+    """Set the default database name."""
+    ...
+
+
+def get_defhost() -> str | None:
+    """Get the default host."""
+    ...
+
+
+def set_defhost(host: str | None) -> None:
+    """Set the default host."""
+    ...
+
+
+def get_defport() -> int | None:
+    """Get the default port."""
+    ...
+
+
+def set_defport(port: int | None) -> None:
+    """Set the default port."""
+    ...
+
+
+def get_defopt() -> str | None:
+    """Get the default connection options."""
+    ...
+
+
+def set_defopt(opt: str | None) -> None:
+    """Set the default connection options."""
+    ...
+
+
+def get_defuser() -> str | None:
+    """Get the default database user."""
+    ...
+
+
+def set_defuser(user: str | None) -> None:
+    """Set the default database user."""
+    ...
+
+
+def get_defpasswd() -> str | None:
+    """Get the default database password."""
+    ...
+
+
+def set_defpasswd(passwd: str | None) -> None:
+    """Set the default database password."""
+    ...
+
+
+def set_query_helpers(*helpers: Callable) -> None:
+    """Set internal query helper functions."""
+    ...
diff --git a/pg/py.typed b/pg/py.typed
new file mode 100644
index 00000000..ea6e1ace
--- /dev/null
+++ b/pg/py.typed
@@ -0,0 +1,4 @@
+# Marker file for PEP 561.
+
+# The pg package uses inline types,
+# except for the _pg extension module which uses a stub file.
diff --git a/pgdb.py b/pgdb/__init__.py
similarity index 97%
rename from pgdb.py
rename to pgdb/__init__.py
index 332ca3d0..74ad38e5 100644
--- a/pgdb.py
+++ b/pgdb/__init__.py
@@ -90,42 +90,8 @@
 )
 from uuid import UUID as Uuid  # noqa: N811
 
-try:
-    from _pg import version
-except ImportError as e:  # noqa: F841
-    import os
-    libpq = 'libpq.'
- if os.name == 'nt': - libpq += 'dll' - import sys - paths = [path for path in os.environ["PATH"].split(os.pathsep) - if os.path.exists(os.path.join(path, libpq))] - if sys.version_info >= (3, 8): - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - add_dll_dir = os.add_dll_directory # type: ignore - for path in paths: - with add_dll_dir(os.path.abspath(path)): - try: - from _pg import version # type: ignore - except ImportError: - pass - else: - del version - e = None # type: ignore - break - if paths: - libpq = 'compatible ' + libpq - else: - libpq += 'so' - if e: - raise ImportError( - "Cannot import shared library for PyGreSQL,\n" - f"probably because no {libpq} is installed.\n{e}") from e -else: - del version - # import objects from extension module -from _pg import ( +from pg import ( RESULT_DQL, DatabaseError, DataError, @@ -143,10 +109,10 @@ unescape_bytea, version, ) -from _pg import ( +from pg import ( Connection as Cnx, # base connection ) -from _pg import ( +from pg import ( connect as get_cnx, # get base connection ) @@ -694,10 +660,9 @@ def __missing__(self, key: int | str) -> TypeCode: res = self._src.fetch(1) if not res: raise KeyError(f'Type {key} could not be found') - res = res[0] + r = res[0] type_code = TypeCode.create( - int(res[0]), res[1], int(res[2]), - res[3], res[4], res[5], int(res[6])) + int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) # noinspection PyUnresolvedReferences self[type_code.oid] = self[str(type_code)] = type_code return type_code @@ -782,19 +747,30 @@ def __getitem__(self, key: str) -> str: # *** Error Messages *** -E = TypeVar('E', bound=DatabaseError) +E = TypeVar('E', bound=Error) -def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: - """Return DatabaseError with empty sqlstate attribute.""" +def _error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" error = cls(msg) - error.sqlstate = None + if isinstance(error, DatabaseError): + error.sqlstate = None return error +def _db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return _error(msg, DatabaseError) + + +def _if_error(msg: str) -> InterfaceError: + """Return InterfaceError.""" + return _error(msg, InterfaceError) + + def _op_error(msg: str) -> OperationalError: """Return OperationalError.""" - return _db_error(msg, OperationalError) + return _error(msg, OperationalError) # *** Row Tuples *** @@ -835,8 +811,8 @@ def __init__(self, connection: Connection) -> None: cnx = connection._cnx if not cnx: raise _op_error("Connection has been closed") - self._cnx = cnx - self.type_cache = connection.type_cache + self._cnx: Cnx = cnx + self.type_cache: TypeCache = connection.type_cache self._src = self._cnx.source() # the official attribute for describing the result columns self._description: list[CursorDescription] | bool | None = None @@ -845,9 +821,9 @@ def __init__(self, connection: Connection) -> None: self.row_factory = None # type: ignore else: self.build_row_factory = None # type: ignore - self.rowcount = -1 - self.arraysize = 1 - self.lastrowid = None + self.rowcount: int | None = -1 + self.arraysize: int = 1 + self.lastrowid: int | None = None def __iter__(self) -> Cursor: """Make cursor compatible to the iteration protocol.""" @@ -1044,8 +1020,7 @@ def executemany(self, operation: str, raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error( - f"Error in '{sql}': '{err}'", InterfaceError) from err + raise _if_error(f"Error in '{sql}': 
'{err}'") from err
         except Exception as err:
             raise _op_error(f"Internal error in '{sql}': {err}") from err
         # then initialize result raw count and description
@@ -1264,7 +1239,8 @@ def chunks() -> Generator:
                 # the following call will re-raise the error
                 putdata(error)
             else:
-                self.rowcount = putdata(None)
+                rowcount = putdata(None)
+                self.rowcount = -1 if rowcount is None else rowcount
 
         # return the cursor object, so you can chain operations
         return self
@@ -1459,7 +1435,7 @@ class Connection:
 
     def __init__(self, cnx: Cnx) -> None:
         """Create a database connection object."""
-        self._cnx = cnx  # connection
+        self._cnx: Cnx | None = cnx  # connection
        self._tnx = False  # transaction state
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
@@ -1509,7 +1485,7 @@ def close(self) -> None:
             with suppress(DatabaseError):
                 self.rollback()
             self._cnx.close()
-            self._cnx = None
+            self._cnx = None
 
     @property
     def closed(self) -> bool:
@@ -1857,5 +1833,5 @@ def __str__(self) -> str:
 
 if __name__ == '__main__':
     print('PyGreSQL version', version)
-    print('')
+    print()
     print(__doc__)
diff --git a/pgdb/py.typed b/pgdb/py.typed
new file mode 100644
index 00000000..ead52d46
--- /dev/null
+++ b/pgdb/py.typed
@@ -0,0 +1 @@
+# Marker file for PEP 561. The pgdb package uses inline types.
diff --git a/pyproject.toml b/pyproject.toml
index 1016b433..e289b38f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,9 +89,13 @@ module = [
 disallow_untyped_defs = false
 
 [tool.setuptools]
-py-modules = ["pg", "pgdb"]
+packages = ["pg", "pgdb"]
 license-files = ["LICENSE.txt"]
 
+[tool.setuptools.package-data]
+pg = ["py.typed"]
+pgdb = ["py.typed"]
+
 [build-system]
 requires = ["setuptools>=68", "wheel>=0.41"]
 build-backend = "setuptools.build_meta"
diff --git a/setup.py b/setup.py
index 3ad3a906..4fd39c56 100755
--- a/setup.py
+++ b/setup.py
@@ -33,9 +33,6 @@
 # classic interface, and "pgdb" for the modern DB-API 2.0 interface.
 # These two top-level Python modules share the same C extension "_pg".
 
-py_modules = ['pg', 'pgdb']
-c_sources = ['pgmodule.c']
-
 def pg_config(s):
     """Retrieve information about installed version of PostgreSQL."""
     f = os.popen(f'pg_config --{s}')  # noqa: S605
@@ -127,27 +124,27 @@ def finalize_options(self):
 
 setup(
-    name="PyGreSQL",
+    name='PyGreSQL',
     version=version,
-    description="Python PostgreSQL Interfaces",
+    description='Python PostgreSQL Interfaces',
     long_description=long_description,
     long_description_content_type='text/x-rst',
-    keywords="pygresql postgresql database api dbapi",
+    keywords='pygresql postgresql database api dbapi',
     author="D'Arcy J. M.
Cain", author_email="darcy@PyGreSQL.org", - url="https://pygresql.github.io/", - download_url="https://pygresql.github.io/contents/download/", + url='https://pygresql.github.io/', + download_url='https://pygresql.github.io/contents/download/', project_urls={ - "Documentation": "https://pygresql.github.io/contents/", - "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", - "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", - "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, + 'Documentation': 'https://pygresql.github.io/contents/', + 'Issue Tracker': 'https://github.com/PyGreSQL/PyGreSQL/issues/', + 'Mailing List': 'https://mail.vex.net/mailman/listinfo/pygresql', + 'Source Code': 'https://github.com/PyGreSQL/PyGreSQL'}, classifiers=[ - "Development Status :: 6 - Mature", - "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", - "Operating System :: OS Independent", - "Programming Language :: C", + 'Development Status :: 6 - Mature', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: PostgreSQL License', + 'Operating System :: OS Independent', + 'Programming Language :: C', 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.7', @@ -155,16 +152,17 @@ def finalize_options(self): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development :: Libraries :: Python Modules"], - license="PostgreSQL", - py_modules=py_modules, + 'Programming Language :: SQL', + 'Topic :: Database', + 'Topic :: Database :: Front-Ends', + 'Topic :: Software Development :: Libraries :: Python Modules'], + license='PostgreSQL', test_suite='tests.discover', zip_safe=False, + packages=["pg", "pgdb"], + package_data={"pg": ["py.typed"], "pgdb": ["py.typed"]}, ext_modules=[Extension( - '_pg', c_sources, + 'pg._pg', ["ext/pgmodule.c"], include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], diff --git a/tests/config.py b/tests/config.py index 0b15f62e..4e27c3ae 100644 --- a/tests/config.py +++ b/tests/config.py @@ -18,13 +18,10 @@ dbname = get('PYGRESQL_DB', get('PGDATABASE', 'test')) dbhost = get('PYGRESQL_HOST', get('PGHOST', 'localhost')) -dbport = get('PYGRESQL_PORT', get('PGPORT', 5432)) +dbport = int(get('PYGRESQL_PORT', get('PGPORT', 5432))) dbuser = get('PYGRESQL_USER', get('PGUSER')) dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) -if dbport: - dbport = int(dbport) - try: from .LOCAL_PyGreSQL import * # type: ignore # noqa except (ImportError, ValueError): diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 242fdbb5..d6a742bf 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -19,7 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from decimal import Decimal -from typing import Sequence +from typing import Any, Sequence import pg # the module under test @@ -999,7 +999,7 @@ def test_query_with_bool_params(self, bool_enabled=None): self.assertEqual(query(q, (False,)).getresult(), r_false) self.assertEqual(query(q, (True,)).getresult(), r_true) finally: - if bool_enabled is not None: + if bool_enabled_default is not None: pg.set_bool(bool_enabled_default) def 
test_query_with_bool_params_not_default(self): @@ -1557,7 +1557,7 @@ def test_single_with_empty_query(self): try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1577,7 +1577,7 @@ def test_single_with_two_rows(self): try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1588,7 +1588,7 @@ def test_single_dict_with_empty_query(self): try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1608,7 +1608,7 @@ def test_single_dict_with_two_rows(self): try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1619,7 +1619,7 @@ def test_single_named_with_empty_query(self): try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1627,7 +1627,7 @@ def test_single_named_with_empty_query(self): def test_single_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") - r = q.singlenamed() + r: Any = q.singlenamed() self.assertEqual(r._fields, ('one', 'two')) self.assertEqual(r.one, 1) self.assertEqual(r.two, 2) @@ -1643,7 +1643,7 @@ def test_single_named_with_two_rows(self): try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1654,7 +1654,7 @@ def test_single_scalar_with_empty_query(self): try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1674,7 +1674,7 @@ def test_single_scalar_with_two_rows(self): try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -2685,38 +2685,38 @@ def setUpClass(cls): def test_escape_string(self): self.assertTrue(self.cls_set_up) f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("das is' käse".encode()) - self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is'' käse".encode()) - r = f("that's cheesy") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheesy") - r = f(r"It's bad to have a \ inside.") - self.assertEqual(r, r"It''s bad to have a \\ inside.") + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, "das is'' käse".encode()) + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + s = f(r"It's bad to have a \ inside.") + self.assertEqual(s, r"It''s bad to have a \\ inside.") def test_escape_bytea(self): self.assertTrue(self.cls_set_up) f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("das is' käse".encode()) - self.assertIsInstance(r, bytes) - self.assertEqual(r, b"das is'' k\\\\303\\\\244se") - r = f("that's cheesy") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheesy") - r = f(b'O\x00ps\xff!') - 
self.assertEqual(r, b'O\\\\000ps\\\\377!') + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, b"das is'' k\\\\303\\\\244se") + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + b = f(b'O\x00ps\xff!') + self.assertEqual(b, b'O\\\\000ps\\\\377!') if __name__ == '__main__': diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 71438f71..74d6df8e 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -3348,24 +3348,26 @@ def test_insert_update_get_bytea(self): def test_upsert_bytea(self): self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" - r = dict(n=7, data=s) - r = self.db.upsert('bytea_test', r) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) + d = dict(n=7, data=s) + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + data = d['data'] if pg.get_bytea_escaped(): - self.assertNotEqual(r['data'], s) - r['data'] = pg.unescape_bytea(r['data']) - self.assertIsInstance(r['data'], bytes) - self.assertEqual(r['data'], s) - r['data'] = None - r = self.db.upsert('bytea_test', r) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) - self.assertIsNone(r['data']) + self.assertNotEqual(data, s) + self.assertIsInstance(data, str) + data = pg.unescape_bytea(data) # type: ignore + self.assertIsInstance(data, bytes) + self.assertEqual(data, s) + d['data'] = None + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + self.assertIsNone(d['data']) def test_insert_get_json(self): self.create_table('json_test', 'n smallint primary key, data json') diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 33c2f6f9..19214c5d 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -287,11 +287,11 @@ def test_parser_nested(self): def test_parser_too_deeply_nested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '{' * n + 'a,b,c' + '}' * n + s = '{' * n + 'a,b,c' + '}' * n if n > 16: # hard coded maximum depth - self.assertRaises(ValueError, f, r) + self.assertRaises(ValueError, f, s) else: - r = f(r) + r = f(s) for _i in range(n - 1): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) @@ -537,9 +537,9 @@ def test_parser_nested(self): def test_parser_many_elements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = ','.join(map(str, range(n))) - r = f'({r})' - r = f(r, int) + s = ','.join(map(str, range(n))) + s = f'({s})' + r = f(s, int) self.assertEqual(r, tuple(range(n))) def test_parser_cast_uniform(self): @@ -877,27 +877,27 @@ class TestEscapeFunctions(unittest.TestCase): def test_escape_string(self): f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") + b = f(b'plain') + 
self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") def test_escape_bytea(self): f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") def test_unescape_bytea(self): f = pg.unescape_bytea diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 6838d03a..2e731c6e 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -869,7 +869,7 @@ def __init__(self, value: Any) -> None: def __str__(self) -> str: return str(self.value).replace('.', ',') - + self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) diff --git a/tox.ini b/tox.ini index 322a3f32..eae93234 100644 --- a/tox.ini +++ b/tox.ini @@ -7,20 +7,20 @@ envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs basepython = python3.11 deps = ruff>=0.0.287 commands = - ruff setup.py pg.py pgdb.py tests + ruff setup.py pg pgdb tests [testenv:mypy] basepython = python3.11 deps = mypy>=1.5.1 commands = - mypy pg.py pgdb.py tests + mypy pg pgdb tests [testenv:cformat] basepython = python3.11 allowlist_externals = sh commands = - sh -c "! (clang-format --style=file -n *.c 2>&1 | tee /dev/tty | grep format-violations)" + sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] basepython = python3.11 From 8758cdaa0e0230f452918e80fe44305d522c8e91 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 00:01:05 +0200 Subject: [PATCH 154/194] Fix manifest file --- MANIFEST.in | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 4ff1c2b6..8d4bbd33 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,11 +1,9 @@ include setup.py -recursive-include pg pgdb tests *.py - -include pg/*.pyi -include pg/py.typed -include pgdb/py.typed +recursive-include pg *.py *.pyi py.typed +recursive-include pgdb *.py py.typed +recursive-include tests *.py include ext/*.c include ext/*.h From 5f861cbe39c19cd92e9cc7180e3128cee5a2bc6f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 00:13:11 +0200 Subject: [PATCH 155/194] Use organization for docs, mention myself in README --- README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 150effb5..a010b944 100644 --- a/README.rst +++ b/README.rst @@ -7,8 +7,10 @@ powerful PostgreSQL features from Python. PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 +D'Arcy J. M. Cain renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. +Christoph Zwerschke volunteered as another maintainer and has been the main +contributor since version 3.7 of PyGreSQL. 
The following Python versions are supported: @@ -16,7 +18,6 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only - Installation ------------ @@ -31,6 +32,6 @@ Documentation ------------- The documentation is available at -`pygresql.github.io/PyGreSQL/ `_ -and at `pygresql.readthedocs.io `_, +`pygresql.github.io/ `_ and at +`pygresql.readthedocs.io `_, where you can also find the documentation for older versions. From 5252d13164c8b50f805ce4e8cb9a23c7190a6088 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 13:14:54 +0200 Subject: [PATCH 156/194] Split pg package into submodules Also fix a small issue in the code for adaptation of records. --- pg/__init__.py | 2728 +----------------------------- pg/adapt.py | 680 ++++++++ pg/attrs.py | 35 + pg/cast.py | 436 +++++ pg/core.py | 135 ++ pg/db.py | 1332 +++++++++++++++ pg/error.py | 35 + pg/helpers.py | 98 ++ pg/notify.py | 149 ++ pg/tz.py | 21 + tests/test_classic_attrdict.py | 100 ++ tests/test_classic_connection.py | 10 +- tests/test_classic_dbwrapper.py | 167 +- tests/test_classic_functions.py | 9 +- 14 files changed, 3096 insertions(+), 2839 deletions(-) create mode 100644 pg/adapt.py create mode 100644 pg/attrs.py create mode 100644 pg/cast.py create mode 100644 pg/core.py create mode 100644 pg/db.py create mode 100644 pg/error.py create mode 100644 pg/helpers.py create mode 100644 pg/notify.py create mode 100644 pg/tz.py create mode 100644 tests/test_classic_attrdict.py diff --git a/pg/__init__.py b/pg/__init__.py index 0740db20..e0e1b214 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -22,70 +22,9 @@ from __future__ import annotations -import select -import weakref -from collections import namedtuple -from contextlib import suppress -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from operator import itemgetter -from re import compile as regex -from types import MappingProxyType -from typing import ( - Any, - Callable, - ClassVar, - Generator, - Iterator, - List, - Mapping, - NamedTuple, - Sequence, - TypeVar, -) -from uuid import UUID - -try: - from ._pg import version -except ImportError as e: # noqa: F841 - import os - libpq = 'libpq.' 
- if os.name == 'nt': - libpq += 'dll' - import sys - paths = [path for path in os.environ["PATH"].split(os.pathsep) - if os.path.exists(os.path.join(path, libpq))] - if sys.version_info >= (3, 8): - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - add_dll_dir = os.add_dll_directory # type: ignore - for path in paths: - with add_dll_dir(os.path.abspath(path)): - try: - from ._pg import version - except ImportError: - pass - else: - del version - e = None # type: ignore - break - if paths: - libpq = 'compatible ' + libpq - else: - libpq += 'so' - if e: - raise ImportError( - "Cannot import shared library for PyGreSQL,\n" - f"probably because no {libpq} is installed.\n{e}") from e -else: - del version - -# import objects from extension module -from ._pg import ( +from .adapt import Adapter, Bytea, Hstore, Json, Literal +from .cast import Typecasts, get_typecast, set_typecast +from .core import ( INV_READ, INV_WRITE, POLLING_FAILED, @@ -155,6 +94,9 @@ unescape_bytea, version, ) +from .db import DB +from .helpers import init_core, set_row_factory_size +from .notify import NotificationHandler __version__ = version @@ -185,2661 +127,9 @@ 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', 'set_typecast', + 'set_jsondecode', 'set_query_helpers', + 'set_row_factory_size', 'set_typecast', 'version', '__version__', ] -# Auxiliary classes and functions that are independent of a DB connection: - -SomeNamedTuple = Any # alias for accessing arbitrary named tuples - -def get_args(func: Callable) -> list: - return list(signature(func).parameters) - - -# time zones used in Postgres timestamptz output -_timezones: dict[str, str] = { - 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', - 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', - 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' -} - - -def _timezone_as_offset(tz: str) -> str: - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def _oid_key(table: str) -> str: - """Build oid key from a table name.""" - return f'oid({table})' - - -class Bytea(bytes): - """Wrapper class for marking Bytea values.""" - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - - @classmethod - def _quote(cls, s: Any) -> str: - if s is None: - return 'NULL' - if not isinstance(s, str): - s = str(s) - if not s: - return '""' - s = s.replace('"', '\\"') - if cls._re_quote.search(s): - s = f'"{s}"' - return s - - def __str__(self) -> str: - """Create a printable representation of the hstore value.""" - q = self._quote - return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) - - -class Json: - """Wrapper class for marking Json values.""" - - def __init__(self, obj: Any, - encode: Callable[[Any], str] | None = None) -> None: - """Initialize the JSON object.""" - self.obj = obj - self.encode = encode or jsonencode - - def __str__(self) -> str: - """Create a printable representation of the JSON object.""" - obj = self.obj - if isinstance(obj, str): - return obj - return self.encode(obj) - - -class _SimpleTypes(dict): - """Dictionary mapping pg_type names to simple type names. - - The corresponding Python types and simple names are also mapped. 
- """ - - _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ - 'bool': [bool], - 'bytea': [Bytea], - 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', - 'abstime', 'reltime', # these are very old - 'datetime', 'timedelta', # these do not really exist - date, time, datetime, timedelta], - 'float': ['float4', 'float8', float], - 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], - 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], - 'num': ['numeric', Decimal], 'money': [], - 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] - }) - - # noinspection PyMissingConstructor - def __init__(self) -> None: - """Initialize type mapping.""" - for typ, keys in self._type_aliases.items(): - keys = [typ, *keys] - for key in keys: - self[key] = typ - if isinstance(key, str): - self[f'_{key}'] = f'{typ}[]' - elif not isinstance(key, tuple): - self[List[key]] = f'{typ}[]' # type: ignore - - @staticmethod - def __missing__(key: str) -> str: - """Unmapped types are interpreted as text.""" - return 'text' - - def get_type_dict(self) -> dict[type, str]: - """Get a plain dictionary of only the types.""" - return {key: typ for key, typ in self.items() - if not isinstance(key, (str, tuple))} - - -_simpletypes = _SimpleTypes() -_simple_type_dict = _simpletypes.get_type_dict() - - -def _quote_if_unqualified(param: str, name: int | str) -> str: - """Quote parameter representing a qualified name. - - Puts a quote_ident() call around the given parameter unless - the name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if isinstance(name, str) and '.' not in name: - return f'quote_ident({param})' - return param - - -class _ParameterList(list): - """Helper class for building typed parameter lists.""" - - adapt: Callable - - def add(self, value: Any, typ:Any = None) -> str: - """Typecast value with known database type and build parameter list. - - If this is a literal value, it will be returned as is. Otherwise, a - placeholder will be returned and the parameter list will be augmented. 
- """ - # noinspection PyUnresolvedReferences - value = self.adapt(value, typ) - if isinstance(value, Literal): - return value - self.append(value) - return f'${len(self)}' - - -class Literal(str): - """Wrapper class for marking literal SQL values.""" - - -class AttrDict(dict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args: Any, **kw: Any) -> None: - self._read_only = False - super().__init__(*args, **kw) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error # type: ignore - self.pop = self.setdefault = self.popitem = error # type: ignore - - def __setitem__(self, key: str, value: Any) -> None: - if self._read_only: - self._read_only_error() - super().__setitem__(key, value) - - def __delitem__(self, key: str) -> None: - if self._read_only: - self._read_only_error() - super().__delitem__(key) - - @staticmethod - def _read_only_error(*_args: Any, **_kw: Any) -> Any: - raise TypeError('This object is read-only') - - -class Adapter: - """Class providing methods for adapting parameters to the database.""" - - _bool_true_values = frozenset('t true 1 y yes on'.split()) - - _date_literals = frozenset( - 'current_date current_time' - ' current_timestamp localtime localtimestamp'.split()) - - _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') - _re_record_quote = regex(r'[(,"\\]') - _re_array_escape = _re_record_escape = regex(r'(["\\])') - - def __init__(self, db: DB): - """Initialize the adapter object with the given connection.""" - self.db = weakref.proxy(db) - - @classmethod - def _adapt_bool(cls, v: Any) -> str | None: - """Adapt a boolean parameter.""" - if isinstance(v, str): - if not v: - return None - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_date(cls, v: Any) -> Any: - """Adapt a date parameter.""" - if not v: - return None - if isinstance(v, str) and v.lower() in cls._date_literals: - return Literal(v) - return v - - @staticmethod - def _adapt_num(v: Any) -> Any: - """Adapt a numeric parameter.""" - if not v and v != 0: - return None - return v - - _adapt_int = _adapt_float = _adapt_money = _adapt_num - - def _adapt_bytea(self, v: Any) -> str: - """Adapt a bytea parameter.""" - return self.db.escape_bytea(v) - - def _adapt_json(self, v: Any) -> str | None: - """Adapt a json parameter.""" - if not v: - return None - if isinstance(v, str): - return v - if isinstance(v, Json): - return str(v) - return self.db.encode_json(v) - - def _adapt_hstore(self, v: Any) -> str | None: - """Adapt a hstore parameter.""" - if not v: - return None - if isinstance(v, str): - return v - if isinstance(v, Hstore): - return str(v) - if isinstance(v, dict): - return str(Hstore(v)) - raise TypeError(f'Hstore parameter {v} has wrong type') - - def _adapt_uuid(self, v: Any) -> str | None: - """Adapt a UUID parameter.""" - if not v: - return None - if isinstance(v, str): - return v - return str(v) - - @classmethod - def _adapt_text_array(cls, v: Any) -> str: - """Adapt a text type array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_text_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if v is None: - return 'null' - if not v: - return '""' - v = str(v) - if cls._re_array_quote.search(v): - v = cls._re_array_escape.sub(r'\\\1', v) - v = f'"{v}"' - return v - - _adapt_date_array = _adapt_text_array - - @classmethod - def _adapt_bool_array(cls, v: Any) -> str: - """Adapt a boolean array parameter.""" - if isinstance(v, list): - adapt = 
cls._adapt_bool_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if v is None: - return 'null' - if isinstance(v, str): - if not v: - return 'null' - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_num_array(cls, v: Any) -> str: - """Adapt a numeric array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_num_array - v = '{' + ','.join(adapt(v) for v in v) + '}' - if not v and v != 0: - return 'null' - return str(v) - - _adapt_int_array = _adapt_float_array = _adapt_money_array = \ - _adapt_num_array - - def _adapt_bytea_array(self, v: Any) -> bytes: - """Adapt a bytea array parameter.""" - if isinstance(v, list): - return b'{' + b','.join( - self._adapt_bytea_array(v) for v in v) + b'}' - if v is None: - return b'null' - return self.db.escape_bytea(v).replace(b'\\', b'\\\\') - - def _adapt_json_array(self, v: Any) -> str: - """Adapt a json array parameter.""" - if isinstance(v, list): - adapt = self._adapt_json_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if not v: - return 'null' - if not isinstance(v, str): - v = self.db.encode_json(v) - if self._re_array_quote.search(v): - v = self._re_array_escape.sub(r'\\\1', v) - v = f'"{v}"' - return v - - def _adapt_record(self, v: Any, typ: Any) -> str: - """Adapt a record parameter with given type.""" - typ = self.get_attnames(typ).values() - if len(typ) != len(v): - raise TypeError(f'Record parameter {v} has wrong size') - adapt = self.adapt - value = [] - for v, t in zip(v, typ): # noqa: B020 - v = adapt(v, t) - if v is None: - v = '' - elif not v: - v = '""' - else: - if isinstance(v, bytes): - if str is not bytes: - v = v.decode('ascii') - else: - v = str(v) - if self._re_record_quote.search(v): - v = self._re_record_escape.sub(r'\\\1', v) - v = f'"{v}"' - value.append(v) - v = ','.join(value) - return f'({v})' - - def adapt(self, value: Any, typ: Any = None) -> str: - """Adapt a value with known database type.""" - if value is not None and not isinstance(value, Literal): - if typ: - simple = self.get_simple_name(typ) - else: - typ = simple = self.guess_simple_type(value) or 'text' - pg_str = getattr(value, '__pg_str__', None) - if pg_str: - value = pg_str(typ) - if simple == 'text': - pass - elif simple == 'record': - if isinstance(value, tuple): - value = self._adapt_record(value, typ) - elif simple.endswith('[]'): - if isinstance(value, list): - adapt = getattr(self, f'_adapt_{simple[:-2]}_array') - value = adapt(value) - else: - adapt = getattr(self, f'_adapt_{simple}') - value = adapt(value) - return value - - @staticmethod - def simple_type(name: str) -> DbType: - """Create a simple database type with given attribute names.""" - typ = DbType(name) - typ.simple = name - return typ - - @staticmethod - def get_simple_name(typ: Any) -> str: - """Get the simple name of a database type.""" - if isinstance(typ, DbType): - # noinspection PyUnresolvedReferences - return typ.simple - return _simpletypes[typ] - - @staticmethod - def get_attnames(typ: Any) -> dict[str, dict[str, str]]: - """Get the attribute names of a composite database type.""" - if isinstance(typ, DbType): - return typ.attnames - return {} - - @classmethod - def guess_simple_type(cls, value: Any) -> str | None: - """Try to guess which database type the given value has.""" - # optimize for most frequent types - try: - return _simple_type_dict[type(value)] - except KeyError: - pass - if isinstance(value, (bytes, str)): - return 'text' - if isinstance(value, bool): - return 'bool' - if 
isinstance(value, int): - return 'int' - if isinstance(value, float): - return 'float' - if isinstance(value, Decimal): - return 'num' - if isinstance(value, (date, time, datetime, timedelta)): - return 'date' - if isinstance(value, Bytea): - return 'bytea' - if isinstance(value, Json): - return 'json' - if isinstance(value, Hstore): - return 'hstore' - if isinstance(value, UUID): - return 'uuid' - if isinstance(value, list): - return (cls.guess_simple_base_type(value) or 'text') + '[]' - if isinstance(value, tuple): - simple_type = cls.simple_type - guess = cls.guess_simple_type - - # noinspection PyUnusedLocal - def get_attnames(self: DbType) -> AttrDict: - return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) - for n, v in enumerate(value)) - - typ = simple_type('record') - typ._get_attnames = get_attnames - return typ - return None - - @classmethod - def guess_simple_base_type(cls, value: Any) -> str | None: - """Try to guess the base type of a given array.""" - for v in value: - if isinstance(v, list): - typ = cls.guess_simple_base_type(v) - else: - typ = cls.guess_simple_type(v) - if typ: - return typ - return None - - def adapt_inline(self, value: Any, nested: bool=False) -> Any: - """Adapt a value that is put into the SQL and needs to be quoted.""" - if value is None: - return 'NULL' - if isinstance(value, Literal): - return value - if isinstance(value, Bytea): - value = self.db.escape_bytea(value).decode('ascii') - elif isinstance(value, (datetime, date, time, timedelta)): - value = str(value) - if isinstance(value, (bytes, str)): - value = self.db.escape_string(value) - return f"'{value}'" - if isinstance(value, bool): - return 'true' if value else 'false' - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, Decimal)): - return value - if isinstance(value, list): - q = self.adapt_inline - s = '[{}]' if nested else 'ARRAY[{}]' - return s.format(','.join(str(q(v, nested=True)) for v in value)) - if isinstance(value, tuple): - q = self.adapt_inline - return '({})'.format(','.join(str(q(v)) for v in value)) - if isinstance(value, Json): - value = self.db.escape_string(str(value)) - return f"'{value}'::json" - if isinstance(value, Hstore): - value = self.db.escape_string(str(value)) - return f"'{value}'::hstore" - pg_repr = getattr(value, '__pg_repr__', None) - if not pg_repr: - raise InterfaceError( - f'Do not know how to adapt type {type(value)}') - value = pg_repr() - if isinstance(value, (tuple, list)): - value = self.adapt_inline(value) - return value - - def parameter_list(self) -> _ParameterList: - """Return a parameter list for parameters with known database types. - - The list has an add(value, typ) method that will build up the - list and return either the literal value or a placeholder. - """ - params = _ParameterList() - params.adapt = self.adapt - return params - - def format_query(self, command: str, - values: list | tuple | dict | None = None, - types: list | tuple | dict | None = None, - inline: bool=False - ) -> tuple[str, _ParameterList]: - """Format a database query using the given values and types. - - The optional types describe the values and must be passed as a list, - tuple or string (that will be split on whitespace) when values are - passed as a list or tuple, or as a dict if values are passed as a dict. - - If inline is set to True, then parameters will be passed inline - together with the query string. 
- """ - params = self.parameter_list() - if not values: - return command, params - if inline and types: - raise ValueError('Typed parameters must be sent separately') - if isinstance(values, (list, tuple)): - if inline: - adapt = self.adapt_inline - seq_literals = [adapt(value) for value in values] - else: - add = params.add - if types: - if isinstance(types, str): - types = types.split() - if (not isinstance(types, (list, tuple)) - or len(types) != len(values)): - raise TypeError('The values and types do not match') - seq_literals = [add(value, typ) - for value, typ in zip(values, types)] - else: - seq_literals = [add(value) for value in values] - command %= tuple(seq_literals) - elif isinstance(values, dict): - # we want to allow extra keys in the dictionary, - # so we first must find the values actually used in the command - used_values = {} - map_literals = dict.fromkeys(values, '') - for key in values: - del map_literals[key] - try: - command % map_literals - except KeyError: - used_values[key] = values[key] # pyright: ignore - map_literals[key] = '' - if inline: - adapt = self.adapt_inline - map_literals = {key: adapt(value) - for key, value in used_values.items()} - else: - add = params.add - if types: - if not isinstance(types, dict): - raise TypeError('The values and types do not match') - map_literals = {key: add(used_values[key], types.get(key)) - for key in sorted(used_values)} - else: - map_literals = {key: add(used_values[key]) - for key in sorted(used_values)} - command %= map_literals - else: - raise TypeError('The values must be passed as tuple, list or dict') - return command, params - - -def cast_bool(value: str) -> Any: - """Cast a boolean value.""" - if not get_bool(): - return value - return value[0] == 't' - - -def cast_json(value: str) -> Any: - """Cast a JSON value.""" - cast = get_jsondecode() - if not cast: - return value - return cast(value) - - -def cast_num(value: str) -> Any: - """Cast a numeric value.""" - return (get_decimal() or float)(value) - - -def cast_money(value: str) -> Any: - """Cast a money value.""" - point = get_decimal_point() - if not point: - return value - if point != '.': - value = value.replace(point, '.') - value = value.replace('(', '-') - value = ''.join(c for c in value if c.isdigit() or c in '.-') - return (get_decimal() or float)(value) - - -def cast_int2vector(value: str) -> list[int]: - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value: str, connection: DB) -> Any: - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. 
- if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - values = value.split() - if values[-1] == 'BC': - return date.min - value = values[0] - if len(value) > 10: - return date.max - format = connection.date_format() - return datetime.strptime(value, format).date() - - -def cast_time(value: str) -> Any: - """Cast a time value.""" - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, format).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value: str) -> Any: - """Cast a timetz value.""" - m = _re_timezone.match(value) - if m: - value, tz = m.groups() - else: - tz = '+0000' - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - value += _timezone_as_offset(tz) - format += '%z' - return datetime.strptime(value, format).timetz() - - -def cast_timestamp(value: str, connection: DB) -> Any: - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = connection.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:5] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -def cast_timestamptz(value: str, connection: DB) -> Any: - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = connection.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - values, tz = values[:-1], values[-1] - else: - if format.startswith('%Y-'): - m = _re_timezone.match(values[1]) - if m: - values[1], tz = m.groups() - else: - tz = '+0000' - else: - values, tz = values[:-1], values[-1] - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - values.append(_timezone_as_offset(tz)) - formats.append('%z') - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' - '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' 
- '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value: str) -> timedelta: - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - s = [v or '0' for v in m.groups()] - secs_ago = s.pop(5) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) - secs_ago = s.pop(5) == '-' - d = [-int(v) for v in s] if ago else [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - years_ago = s.pop(0) == '-' - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError(f'Cannot parse interval: {value}') - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. The - passed string will never be None since NULL values are already - handled before the cast function is called. - - Note that the basic types are already handled by the C extension. - They only need to be handled here as record or array components. - """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults: ClassVar[dict[str, Callable]] = { - 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, 'sql_identifier': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, - 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, - 'float4': float, 'float8': float, - 'numeric': cast_num, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': UUID, - 'anyarray': cast_array, 'record': cast_record} # pyright: ignore - - connection: DB | None = None # set in a connection specific instance - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. 
- """ - if not isinstance(typ, str): - raise TypeError(f'Invalid type: {typ}') - cast: Callable | None = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - attnames = self.get_attnames(typ) - if attnames: - casts = [self[v.pgtype] for v in attnames.values()] - cast = self.create_record_cast(typ, attnames, casts) - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func: Callable) -> bool: - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - return 'connection' in args[1:] - - def _add_connection(self, cast: Callable) -> Callable: - """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): - return cast - return partial(cast, connection=self.connection) - - def get(self, typ: str, default: Callable | None = None # type: ignore - ) -> Callable | None: - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, str): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop(f'_{t}', None) - - def reset(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. - """ - if typ is None: - self.clear() - else: - if isinstance(typ, str): - typ = [typ] - for t in typ: - self.pop(t, None) - - @classmethod - def get_default(cls, typ: str) -> Any: - """Get the default typecast function for the given database type.""" - return cls.defaults.get(typ) - - @classmethod - def set_default(cls, typ: str | Sequence[str], - cast: Callable | None) -> None: - """Set a default typecast function for the given database type(s).""" - if isinstance(typ, str): - typ = [typ] - defaults = cls.defaults - if cast is None: - for t in typ: - defaults.pop(t, None) - defaults.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - defaults[t] = cast - defaults.pop(f'_{t}', None) - - # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_attnames(self, typ: Any) -> AttrDict: - """Return the fields for the given record type. - - This method will be replaced with the get_attnames() method of DbTypes. - """ - return AttrDict() - - # noinspection PyMethodMayBeStatic - def dateformat(self) -> str: - """Return the current date format. - - This method will be replaced with the dateformat() method of DbTypes. 
- """ - return '%Y-%m-%d' - - def create_array_cast(self, basecast: Callable) -> Callable: - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - - def cast(v: Any) -> list: - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name: str, fields: AttrDict, - casts: list[Callable]) -> Callable: - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) # type: ignore - - def cast(v: Any) -> record: - # noinspection PyArgumentList - return record(*cast_record(v, casts)) - return cast - - -def get_typecast(typ: str) -> Callable | None: - """Get the global typecast function for the given database type.""" - return Typecasts.get_default(typ) - - -def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call db.db_types.reset_typecast(). - """ - Typecasts.set_default(typ, cast) - - -class DbType(str): - """Class augmenting the simple type name with additional info. - - The following additional information is provided: - - oid: the PostgreSQL type OID - pgtype: the internal PostgreSQL data type name - regtype: the registered PostgreSQL data type name - simple: the more coarse-grained PyGreSQL type name - typlen: the internal size, negative if variable - typtype: b = base type, c = composite type etc. - category: A = Array, b = Boolean, C = Composite etc. - delim: delimiter for array types - relid: corresponding table for composite types - attnames: attributes for composite types - """ - - oid: int - pgtype: str - regtype: str - simple: str - typlen: int - typtype: str - category: str - delim: str - relid: int - - _get_attnames: Callable[[DbType], AttrDict] - - @property - def attnames(self) -> AttrDict: - """Get names and types of the fields of a composite type.""" - # noinspection PyUnresolvedReferences - return self._get_attnames(self) - - -class DbTypes(dict): - """Cache for PostgreSQL data types. - - This cache maps type OIDs and names to DbType objects containing - information on the associated database type. 
- """ - - _num_types = frozenset('int float num money int2 int4 int8' - ' float4 float8 numeric money'.split()) - - def __init__(self, db: DB) -> None: - """Initialize type cache for connection.""" - super().__init__() - self._db = weakref.proxy(db) - self._regtypes = False - self._typecasts = Typecasts() - self._typecasts.get_attnames = self.get_attnames # type: ignore - self._typecasts.connection = self._db - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") - - def add(self, oid: int, pgtype: str, regtype: str, - typlen: int, typtype: str, category: str, delim: str, relid: int - ) -> DbType: - """Create a PostgreSQL type name with additional info.""" - if oid in self: - return self[oid] - simple = 'record' if relid else _simpletypes[pgtype] - typ = DbType(regtype if self._regtypes else simple) - typ.oid = oid - typ.simple = simple - typ.pgtype = pgtype - typ.regtype = regtype - typ.typlen = typlen - typ.typtype = typtype - typ.category = category - typ.delim = delim - typ.relid = relid - typ._get_attnames = self.get_attnames # type: ignore - return typ - - def __missing__(self, key: int | str) -> DbType: - """Get the type info from the database if it is not cached.""" - try: - cmd = self._query_pg_type.format(_quote_if_unqualified('$1', key)) - res = self._db.query(cmd, (key,)).getresult() - except ProgrammingError: - res = None - if not res: - raise KeyError(f'Type {key} could not be found') - res = res[0] - typ = self.add(*res) - self[typ.oid] = self[typ.pgtype] = typ - return typ - - def get(self, key: int | str, # type: ignore - default: DbType | None = None) -> DbType | None: - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_attnames(self, typ: Any) -> AttrDict | None: - """Get names and types of the fields of a composite type.""" - if not isinstance(typ, DbType): - typ = self.get(typ) - if not typ: - return None - if not typ.relid: - return None - return self._db.get_attnames(typ.relid, with_oid=False) - - def get_typecast(self, typ: Any) -> Callable | None: - """Get the typecast function for the given database type.""" - return self._typecasts.get(typ) - - def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value: Any, typ: str) -> Any: - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - if not isinstance(typ, DbType): - db_type = self.get(typ) - if db_type: - typ = db_type.pgtype - cast = self.get_typecast(typ) if typ else None - if not cast or cast is str: - # no typecast is necessary - return value - return cast(value) - - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. 
- -# noinspection PyUnresolvedReferences -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore - - -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. - """ - # noinspection PyGlobalUndefined - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -# Helper functions used by the query object - -def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: - """Get query result as an iterator of dictionaries.""" - fields: tuple[str, ...] = q.listfields() - for r in q: - yield dict(zip(fields, r)) - - -def _namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: - """Get query result as an iterator of named tuples.""" - row = _row_factory(q.listfields()) - for r in q: - yield row(r) - - -def _namednext(q: Query) -> SomeNamedTuple: - """Get next row from query result as a named tuple.""" - return _row_factory(q.listfields())(next(q)) - - -def _scalariter(q: Query) -> Generator[Any, None, None]: - """Get query result as an iterator of scalar values.""" - for r in q: - yield r[0] - - -class _MemoryQuery: - """Class that embodies a given query result.""" - - result: Any - fields: tuple[str, ...] - - def __init__(self, result: Any, fields: Sequence[str]) -> None: - """Create query from given result rows and field names.""" - self.result = result - self.fields = tuple(fields) - - def listfields(self) -> tuple[str, ...]: - """Return the stored field names of this query.""" - return self.fields - - def getresult(self) -> Any: - """Return the stored result of this query.""" - return self.result - - def __iter__(self) -> Iterator[Any]: - return iter(self.result) - -# Error messages - -E = TypeVar('E', bound=Error) - -def _error(msg: str, cls: type[E]) -> E: - """Return specified error object with empty sqlstate attribute.""" - error = cls(msg) - if isinstance(error, DatabaseError): - error.sqlstate = None - return error - - -def _db_error(msg: str) -> DatabaseError: - """Return DatabaseError.""" - return _error(msg, DatabaseError) - - -def _int_error(msg: str) -> InternalError: - """Return InternalError.""" - return _error(msg, InternalError) - - -def _prg_error(msg: str) -> ProgrammingError: - """Return ProgrammingError.""" - return _error(msg, ProgrammingError) - - -# Initialize the C module - -set_decimal(Decimal) -set_jsondecode(jsondecode) -set_query_helpers(_dictiter, _namediter, _namednext, _scalariter) - - -# The notification handler - -class NotificationHandler: - """A PostgreSQL client-side asynchronous notification handler.""" - - def __init__(self, db: DB, event: str, callback: Callable, - arg_dict: dict | None = None, - timeout: int | float | None = None, - stop_event: str | None = None): - """Initialize the notification handler. - - You must pass a PyGreSQL database connection, the name of an - event (notification channel) to listen for and a callback function. 
-
- You can also specify a dictionary arg_dict that will be passed as
- the single argument to the callback function, and a timeout value
- in seconds (a floating point number denotes fractions of seconds).
- If it is absent or None, the caller will never time out. If the
- timeout is reached, the callback function will be called with a
- single argument that is None. If you set the timeout to zero,
- the handler will poll notifications synchronously and return.
-
- You can specify the name of the event that will be used to signal
- the handler to stop listening as stop_event. By default, it will
- be the event name prefixed with 'stop_'.
- """
- self.db: DB | None = db
- self.event = event
- self.stop_event = stop_event or f'stop_{event}'
- self.listening = False
- self.callback = callback
- if arg_dict is None:
- arg_dict = {}
- self.arg_dict = arg_dict
- self.timeout = timeout
-
- def __del__(self) -> None:
- """Delete the notification handler."""
- self.unlisten()
-
- def close(self) -> None:
- """Stop listening and close the connection."""
- if self.db:
- self.unlisten()
- self.db.close()
- self.db = None
-
- def listen(self) -> None:
- """Start listening for the event and the stop event."""
- db = self.db
- if db and not self.listening:
- db.query(f'listen "{self.event}"')
- db.query(f'listen "{self.stop_event}"')
- self.listening = True
-
- def unlisten(self) -> None:
- """Stop listening for the event and the stop event."""
- db = self.db
- if db and self.listening:
- db.query(f'unlisten "{self.event}"')
- db.query(f'unlisten "{self.stop_event}"')
- self.listening = False
-
- def notify(self, db: DB | None = None, stop: bool = False,
- payload: str | None = None) -> Query | None:
- """Generate a notification.
-
- Optionally, you can pass a payload with the notification.
-
- If you set the stop flag, a stop notification will be sent that
- will cause the handler to stop listening.
-
- Note: If the notification handler is running in another thread, you
- must pass a different database connection since PyGreSQL database
- connections are not thread-safe.
- """
- if not self.listening:
- return None
- if not db:
- db = self.db
- if not db:
- return None
- event = self.stop_event if stop else self.event
- cmd = f'notify "{event}"'
- if payload:
- cmd += f", '{payload}'"
- return db.query(cmd)
-
- def __call__(self) -> None:
- """Invoke the notification handler.
-
- The handler is a loop that listens for notifications on the event
- and stop event channels. When either of these notifications is
- received, its associated 'pid', 'event' and 'extra' (the payload
- passed with the notification) are inserted into its arg_dict
- dictionary and the callback is invoked with this dictionary as
- a single argument. When the handler receives a stop event, it
- stops listening to both events and returns.
-
- In the special case that the timeout of the handler has been set
- to zero, the handler will poll all events synchronously and return.
- It will keep listening until it receives a stop event.
-
- Note: If you run this loop in another thread, don't use the same
- database connection for database operations in the main thread.
- """
- if not self.db:
- return
- self.listen()
- poll = self.timeout == 0
- rlist = [] if poll else [self.db.fileno()]
- while self.db and self.listening:
- # noinspection PyUnboundLocalVariable
- if poll or select.select(rlist, [], [], self.timeout)[0]:
- while self.db and self.listening:
- notice = self.db.getnotify()
- if not notice: # no more messages
- break
- event, pid, extra = notice
- if event not in (self.event, self.stop_event):
- self.unlisten()
- raise _db_error(
- f'Listening for "{self.event}"'
- f' and "{self.stop_event}",'
- f' but notified of "{event}"')
- if event == self.stop_event:
- self.unlisten()
- self.arg_dict.update(pid=pid, event=event, extra=extra)
- self.callback(self.arg_dict)
- if poll:
- break
- else: # we timed out
- self.unlisten()
- self.callback(None)
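To make the flow above concrete, here is a hedged usage sketch (the database name, channel name and payload are hypothetical; the blocking handler call would normally run in its own thread):

    import pg

    def callback(arg_dict):
        if arg_dict is None:
            print('timed out')  # the timeout was reached
        else:
            print('notified:', arg_dict['event'],
                  arg_dict['pid'], arg_dict['extra'])

    db = pg.DB(dbname='test')
    handler = db.notification_handler('my_event', callback, timeout=10)
    handler()  # listens on "my_event" and "stop_my_event" until stopped

    # From a second connection, e.g. in another thread:
    other = pg.DB(dbname='test')
    handler.notify(other, payload='hello')  # invokes the callback
    handler.notify(other, stop=True)        # makes the loop above return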
- """ - if not self.db: - return - self.listen() - poll = self.timeout == 0 - rlist = [] if poll else [self.db.fileno()] - while self.db and self.listening: - # noinspection PyUnboundLocalVariable - if poll or select.select(rlist, [], [], self.timeout)[0]: - while self.db and self.listening: - notice = self.db.getnotify() - if not notice: # no more messages - break - event, pid, extra = notice - if event not in (self.event, self.stop_event): - self.unlisten() - raise _db_error( - f'Listening for "{self.event}"' - f' and "{self.stop_event}",' - f' but notified of "{event}"') - if event == self.stop_event: - self.unlisten() - self.arg_dict.update(pid=pid, event=event, extra=extra) - self.callback(self.arg_dict) - if poll: - break - else: # we timed out - self.unlisten() - self.callback(None) - - -# The actual PostgreSQL database connection interface: - -class DB: - """Wrapper class for the _pg connection type.""" - - db: Connection | None = None # invalid fallback for underlying connection - _db_args: Any # either the connectoin args or the underlying connection - - def __init__(self, *args: Any, **kw: Any) -> None: - """Create a new connection. - - You can pass either the connection parameters or an existing - _pg or pgdb connection. This allows you to use the methods - of the classic pg interface with a DB-API 2 pgdb connection. - """ - if not args and len(kw) == 1: - db = kw.get('db') - elif not kw and len(args) == 1: - db = args[0] - else: - db = None - if db: - if isinstance(db, DB): - db = db.db - else: - with suppress(AttributeError): - # noinspection PyUnresolvedReferences - db = db._cnx - if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): - db = connect(*args, **kw) - self._db_args = args, kw - self._closeable = True - else: - self._db_args = db - self._closeable = False - self.db = db - self.dbname = db.db - self._regtypes = False - self._attnames: dict[str, AttrDict] = {} - self._generated: dict[str, frozenset[str]] = {} - self._pkeys: dict[str, str | tuple[str, ...]] = {} - self._privileges: dict[tuple[str, str], bool] = {} - self.adapter = Adapter(self) - self.dbtypes = DbTypes(self) - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND {} AND NOT a.attisdropped ORDER BY a.attnum") - if db.server_version < 120000: - self._query_generated = ( - "a.attidentity OPERATOR(pg_catalog.=) 'a'" - ) - else: - self._query_generated = ( - "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR" - " a.attgenerated OPERATOR(pg_catalog.!=) '')" - ) - db.set_cast_hook(self.dbtypes.typecast) - # For debugging scripts, self.debug can be set - # * to a string format specification (e.g. in CGI set to "%s
"), - # * to a file object to write debug statements or - # * to a callable object which takes a string argument - # * to any other true value to just print debug statements - self.debug: Any = None - - def __getattr__(self, name: str) -> Any: - """Get the specified attritbute of the connection.""" - # All undefined members are same as in underlying connection: - if self.db: - return getattr(self.db, name) - else: - raise _int_error('Connection is not valid') - - def __dir__(self) -> list[str]: - """List all attributes of the connection.""" - # Custom dir function including the attributes of the connection: - attrs = set(self.__class__.__dict__) - attrs.update(self.__dict__) - attrs.update(dir(self.db)) - return sorted(attrs) - - # Context manager methods - - def __enter__(self) -> DB: - """Enter the runtime context. This will start a transaction.""" - self.begin() - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context. This will end the transaction.""" - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def __del__(self) -> None: - """Delete the connection.""" - try: - db = self.db - except AttributeError: - db = None - if db: - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - with suppress(InternalError): # when already closed - db.close() - - # Auxiliary methods - - def _do_debug(self, *args: Any) -> None: - """Print a debug message.""" - if self.debug: - s = '\n'.join(str(arg) for arg in args) - if isinstance(self.debug, str): - print(self.debug % s) - elif hasattr(self.debug, 'write'): - # noinspection PyCallingNonCallable - self.debug.write(s + '\n') - elif callable(self.debug): - self.debug(s) - else: - print(s) - - def _escape_qualified_name(self, s: str) -> str: - """Escape a qualified name. - - Escapes the name for use as an SQL identifier, unless the - name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if '.' not in s: - s = self.escape_identifier(s) - return s - - @staticmethod - def _make_bool(d: Any) -> bool | str: - """Get boolean value corresponding to d.""" - return bool(d) if get_bool() else ('t' if d else 'f') - - @staticmethod - def _list_params(params: Sequence) -> str: - """Create a human readable parameter list.""" - return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) - - @property - def _valid_db(self) -> Connection: - """Get underlying connection and make sure it is not closed.""" - db = self.db - if not db: - raise _int_error('Connection already closed') - return db - - # Public methods - - # escape_string and escape_bytea exist as methods, - # so we define unescape_bytea as a method as well - unescape_bytea = staticmethod(unescape_bytea) - - @staticmethod - def decode_json(s: str) -> Any: - """Decode a JSON string coming from the database.""" - return (get_jsondecode() or jsondecode)(s) - - @staticmethod - def encode_json(d: Any) -> str: - """Encode a JSON string for use within SQL.""" - return jsonencode(d) - - def close(self) -> None: - """Close the database connection.""" - # Wraps shared library function so we can track state. - db = self._valid_db - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - db.close() - self.db = None - - def reset(self) -> None: - """Reset connection with current parameters. 
- - All derived queries and large objects derived from this connection - will not be usable after this call. - """ - self._valid_db.reset() - - def reopen(self) -> None: - """Reopen connection to the database. - - Used in case we need another connection to the same database. - Note that we can still reopen a database that we have closed. - """ - # There is no such shared library function. - if self._closeable: - args, kw = self._db_args - db = connect(*args, **kw) - if self.db: - self.db.set_cast_hook(None) - self.db.close() - db.set_cast_hook(self.dbtypes.typecast) - self.db = db - else: - self.db = self._db_args - - def begin(self, mode: str | None = None) -> Query: - """Begin a transaction.""" - qstr = 'BEGIN' - if mode: - qstr += ' ' + mode - return self.query(qstr) - - start = begin - - def commit(self) -> Query: - """Commit the current transaction.""" - return self.query('COMMIT') - - end = commit - - def rollback(self, name: str | None = None) -> Query: - """Roll back the current transaction.""" - qstr = 'ROLLBACK' - if name: - qstr += ' TO ' + name - return self.query(qstr) - - abort = rollback - - def savepoint(self, name: str) -> Query: - """Define a new savepoint within the current transaction.""" - return self.query('SAVEPOINT ' + name) - - def release(self, name: str) -> Query: - """Destroy a previously defined savepoint.""" - return self.query('RELEASE ' + name) - - def get_parameter(self, - parameter: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str] | dict[str, Any] - ) -> str | list[str] | dict[str, str]: - """Get the value of a run-time parameter. - - If the parameter is a string, the return value will also be a string - that is the current setting of the run-time parameter with that name. - - You can get several parameters at once by passing a list, set or dict. - When passing a list of parameter names, the return value will be a - corresponding list of parameter settings. When passing a set of - parameter names, a new dict will be returned, mapping these parameter - names to their settings. Finally, if you pass a dict as parameter, - its values will be set to the current parameter settings corresponding - to its keys. - - By passing the special name 'all' as the parameter, you can get a dict - of all existing configuration parameters. - """ - values: Any - if isinstance(parameter, str): - parameter = [parameter] - values = None - elif isinstance(parameter, (list, tuple)): - values = [] - elif isinstance(parameter, (set, frozenset)): - values = {} - elif isinstance(parameter, dict): - values = parameter - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - query = self._valid_db.query - params: Any = {} if isinstance(values, dict) else [] - for param_key in parameter: - param = param_key.strip().lower() if isinstance( - param_key, (bytes, str)) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - cmd = 'SHOW ALL' - values = query(cmd).getresult() - values = {value[0]: value[1] for value in values} - break - if isinstance(params, dict): - params[param] = param_key - else: - params.append(param) - else: - for param in params: - cmd = f'SHOW {param}' - value = query(cmd).singlescalar() - if values is None: - values = value - elif isinstance(values, list): - values.append(value) - else: - values[params[param]] = value - return values - - def set_parameter(self, - parameter: str | list[str] | tuple[str, ...] 
| - set[str] | frozenset[str] | dict[str, Any], - value: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str]| None = None, - local: bool = False) -> None: - """Set the value of a run-time parameter. - - If the parameter and the value are strings, the run-time parameter - will be set to that value. If no value or None is passed as a value, - then the run-time parameter will be restored to its default value. - - You can set several parameters at once by passing a list of parameter - names, together with a single value that all parameters should be - set to or with a corresponding list of values. You can also pass - the parameters as a set if you only provide a single value. - Finally, you can pass a dict with parameter names as keys. In this - case, you should not pass a value, since the values for the parameters - will be taken from the dict. - - By passing the special name 'all' as the parameter, you can reset - all existing settable run-time parameters to their default values. - - If you set local to True, then the command takes effect for only the - current transaction. After commit() or rollback(), the session-level - setting takes effect again. Setting local to True will appear to - have no effect if it is executed outside a transaction, since the - transaction will end immediately. - """ - if isinstance(parameter, str): - parameter = {parameter: value} - elif isinstance(parameter, (list, tuple)): - if isinstance(value, (list, tuple)): - parameter = dict(zip(parameter, value)) - else: - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, (set, frozenset)): - if isinstance(value, (list, tuple, set, frozenset)): - value = set(value) - if len(value) == 1: - value = next(iter(value)) - if not (value is None or isinstance(value, str)): - raise ValueError( - 'A single value must be specified' - ' when parameter is a set') - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, dict): - if value is not None: - raise ValueError( - 'A value must not be specified' - ' when parameter is a dictionary') - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - params: dict[str, str | None] = {} - for param, param_value in parameter.items(): - param = param.strip().lower() if isinstance(param, str) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - if param_value is not None: - raise ValueError( - 'A value must not be specified' - " when parameter is 'all'") - params = {'all': None} - break - params[param] = param_value - local_clause = ' LOCAL' if local else '' - for param, param_value in params.items(): - cmd = (f'RESET{local_clause} {param}' - if param_value is None else - f'SET{local_clause} {param} TO {param_value}') - self._do_debug(cmd) - self._valid_db.query(cmd) - - def query(self, command: str, *args: Any) -> Query: - """Execute a SQL command string. - - This method simply sends a SQL query to the database. If the query is - an insert statement that inserted exactly one row into a table that - has OIDs, the return value is the OID of the newly inserted row. - If the query is an update or delete statement, or an insert statement - that did not insert exactly one row in a table with OIDs, then the - number of rows affected is returned as a string. If it is a statement - that returns rows as a result (usually a select statement, but maybe - also an "insert/update ... 
returning" statement), this method returns - a Query object that can be accessed via getresult() or dictresult() - or simply printed. Otherwise, it returns `None`. - - The query can contain numbered parameters of the form $1 in place - of any data constant. Arguments given after the query string will - be substituted for the corresponding numbered parameter. Parameter - values can also be given as a single list or tuple argument. - """ - # Wraps shared library function for debugging. - db = self._valid_db - if args: - self._do_debug(command, args) - return db.query(command, args) - self._do_debug(command) - return db.query(command) - - def query_formatted(self, command: str, - parameters: tuple | list | dict | None = None, - types: tuple | list | dict | None = None, - inline: bool =False) -> Query: - """Execute a formatted SQL command string. - - Similar to query, but using Python format placeholders of the form - %s or %(names)s instead of PostgreSQL placeholders of the form $1. - The parameters must be passed as a tuple, list or dict. You can - also pass a corresponding tuple, list or dict of database types in - order to format the parameters properly in case there is ambiguity. - - If you set inline to True, the parameters will be sent to the database - embedded in the SQL command, otherwise they will be sent separately. - """ - return self.query(*self.adapter.format_query( - command, parameters, types, inline)) - - def query_prepared(self, name: str, *args: Any) -> Query: - """Execute a prepared SQL statement. - - This works like the query() method, except that instead of passing - the SQL command, you pass the name of a prepared statement. If you - pass an empty name, the unnamed statement will be executed. - """ - if name is None: - name = '' - db = self._valid_db - if args: - self._do_debug('EXECUTE', name, args) - return db.query_prepared(name, args) - self._do_debug('EXECUTE', name) - return db.query_prepared(name) - - def prepare(self, name: str, command: str) -> None: - """Create a prepared SQL statement. - - This creates a prepared statement for the given command with the - given name for later execution with the query_prepared() method. - - The name can be empty to create an unnamed statement, in which case - any pre-existing unnamed statement is automatically replaced; - otherwise it is an error if the statement name is already - defined in the current database session. We recommend always using - named queries, since unnamed queries have a limited lifetime and - can be automatically replaced or destroyed by various operations. - """ - if name is None: - name = '' - self._do_debug('prepare', name, command) - self._valid_db.prepare(name, command) - - def describe_prepared(self, name: str | None = None) -> Query: - """Describe a prepared SQL statement. - - This method returns a Query object describing the result columns of - the prepared statement with the given name. If you omit the name, - the unnamed statement will be described if you created one before. - """ - if name is None: - name = '' - return self._valid_db.describe_prepared(name) - - def delete_prepared(self, name: str | None = None) -> Query: - """Delete a prepared SQL statement. - - This deallocates a previously prepared SQL statement with the given - name, or deallocates all prepared statements if you do not specify a - name. Note that prepared statements are also deallocated automatically - when the current session ends. 
- """ - if not name: - name = 'ALL' - cmd = f"DEALLOCATE {name}" - self._do_debug(cmd) - return self._valid_db.query(cmd) - - def pkey(self, table: str, composite: bool = False, flush: bool = False - ) -> str | tuple[str, ...]: - """Get the primary key of a table. - - Single primary keys are returned as strings unless you - set the composite flag. Composite primary keys are always - represented as tuples. Note that this raises a KeyError - if the table does not have a primary key. - - If flush is set then the internal cache for primary keys will - be flushed. This may be necessary after the database schema or - the search path has been changed. - """ - pkeys = self._pkeys - if flush: - pkeys.clear() - self._do_debug('The pkey cache has been flushed') - try: # cache lookup - pkey = pkeys[table] - except KeyError as e: # cache miss, check the database - cmd = ("SELECT" # noqa: S608 - " a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum").format( - _quote_if_unqualified('$1', table)) - res = self._valid_db.query(cmd, (table,)).getresult() - if not res: - raise KeyError(f'Table {table} has no primary key') from e - # we want to use the order defined in the primary key index here, - # not the order as defined by the columns in the table - if len(res) > 1: - indkey = res[0][2] - pkey = tuple(row[0] for row in sorted( - res, key=lambda row: indkey.index(row[1]))) - else: - pkey = res[0][0] - pkeys[table] = pkey # cache it - if composite and not isinstance(pkey, tuple): - pkey = (pkey,) - return pkey - - def pkeys(self, table: str) -> tuple[str, ...]: - """Get the primary key of a table as a tuple. - - Same as pkey() with 'composite' set to True. - """ - return self.pkey(table, True) # type: ignore - - def get_databases(self) -> list[str]: - """Get list of databases in the system.""" - return [r[0] for r in self._valid_db.query( - 'SELECT datname FROM pg_catalog.pg_database').getresult()] - - def get_relations(self, kinds: str | Sequence[str] | None = None, - system: bool = False) -> list[str]: - """Get list of relations in connected database of specified kinds. - - If kinds is None or empty, all kinds of relations are returned. - Otherwise, kinds can be a string or sequence of type letters - specifying which kind of relations you want to list. - - Set the system flag if you want to get the system relations as well. - """ - where_parts = [] - if kinds: - where_parts.append( - "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) - if not system: - where_parts.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' - cmd = ("SELECT" # noqa: S608 - " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" - " ORDER BY s.nspname, r.relname") - return [r[0] for r in self._valid_db.query(cmd).getresult()] - - def get_tables(self, system: bool = False) -> list[str]: - """Return list of tables in connected database. - - Set the system flag if you want to get the system tables as well. 
- """ - return self.get_relations('r', system) - - def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False - ) -> AttrDict: - """Given the name of a table, dig out the set of attribute names. - - Returns a read-only dictionary of attribute names (the names are - the keys, the values are the names of the attributes' types) - with the column names in the proper order if you iterate over it. - - If flush is set, then the internal cache for attribute names will - be flushed. This may be necessary after the database schema or - the search path has been changed. - - By default, only a limited number of simple types will be returned. - You can get the registered types after calling use_regtypes(True). - """ - attnames = self._attnames - if flush: - attnames.clear() - self._do_debug('The attnames cache has been flushed') - try: # cache lookup - names = attnames[table] - except KeyError: # cache miss, check the database - cmd = "a.attnum OPERATOR(pg_catalog.>) 0" - if with_oid: - cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" - cmd = self._query_attnames.format( - _quote_if_unqualified('$1', table), cmd) - res = self._valid_db.query(cmd, (table,)).getresult() - types = self.dbtypes - names = AttrDict((name[0], types.add(*name[1:])) for name in res) - attnames[table] = names # cache it - return names - - def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: - """Given the name of a table, dig out the set of generated columns. - - Returns a set of column names that are generated and unalterable. - - If flush is set, then the internal cache for generated columns will - be flushed. This may be necessary after the database schema or - the search path has been changed. - """ - generated = self._generated - if flush: - generated.clear() - self._do_debug('The generated cache has been flushed') - try: # cache lookup - names = generated[table] - except KeyError: # cache miss, check the database - cmd = "a.attnum OPERATOR(pg_catalog.>) 0" - cmd = f"{cmd} AND {self._query_generated}" - cmd = self._query_attnames.format( - _quote_if_unqualified('$1', table), cmd) - res = self._valid_db.query(cmd, (table,)).getresult() - names = frozenset(name[0] for name in res) - generated[table] = names # cache it - return names - - def use_regtypes(self, regtypes: bool | None = None) -> bool: - """Use registered type names instead of simplified type names.""" - if regtypes is None: - return self.dbtypes._regtypes - regtypes = bool(regtypes) - if regtypes != self.dbtypes._regtypes: - self.dbtypes._regtypes = regtypes - self._attnames.clear() - self.dbtypes.clear() - return regtypes - - def has_table_privilege(self, table: str, privilege: str = 'select', - flush: bool = False) -> bool: - """Check whether current user has specified table privilege. - - If flush is set, then the internal cache for table privileges will - be flushed. This may be necessary after privileges have been changed. 
- """ - privileges = self._privileges - if flush: - privileges.clear() - self._do_debug('The privileges cache has been flushed') - privilege = privilege.lower() - try: # ask cache - ret = privileges[table, privilege] - except KeyError: # cache miss, ask the database - cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( - _quote_if_unqualified('$1', table)) - query = self._valid_db.query(cmd, (table, privilege)) - ret = query.singlescalar() == self._make_bool(True) - privileges[table, privilege] = ret # cache it - return ret - - def get(self, table: str, row: Any, - keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: - """Get a row from a database table or view. - - This method is the basic mechanism to get a single row. It assumes - that the keyname specifies a unique row. It must be the name of a - single column or a tuple of column names. If the keyname is not - specified, then the primary key for the table is used. - - If row is a dictionary, then the value for the key is taken from it. - Otherwise, the row must be a single value or a tuple of values - corresponding to the passed keyname or primary key. The fetched row - from the table will be returned as a new dictionary or used to replace - the existing values when row was passed as a dictionary. - - The OID is also put into the dictionary if the table has one, but - in order to allow the caller to work with multiple tables, it is - munged as "oid(table)" using the actual name of the table. - """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if keyname and isinstance(keyname, str): - keyname = (keyname,) - if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if not keyname: - try: # if keyname is not specified, try using the primary key - keyname = self.pkeys(table) - except KeyError as e: # the table has no primary key - # try using the oid instead - if qoid and isinstance(row, dict) and 'oid' in row: - keyname = ('oid',) - else: - raise _prg_error( - f'Table {table} has no primary key') from e - else: # the table has a primary key - # check whether all key columns have values - if isinstance(row, dict) and not set(keyname).issubset(row): - # try using the oid instead - if qoid and 'oid' in row: - keyname = ('oid',) - else: - raise KeyError( - 'Missing value in row for specified keyname') - if not isinstance(row, dict): - if not isinstance(row, (tuple, list)): - row = [row] - if len(keyname) != len(row): - raise KeyError( - 'Differing number of items in keyname and row') - row = dict(zip(keyname, row)) - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - what = 'oid, *' if qoid else '*' - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - t = self._escape_qualified_name(table) - cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if not res: - # make where clause in error message better readable - where = where.replace('OPERATOR(pg_catalog.=)', '=') - raise _db_error( - f'No such record in {table}\nwhere {where}\nwith ' - + self._list_params(params)) - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - 
- def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any
- ) -> dict[str, Any]:
- """Insert a row into a database table.
-
- This method inserts a row into a table. The name of the table must
- be passed as the first parameter. The other parameters are used for
- providing the data of the row that shall be inserted into the table.
- If a dictionary is supplied as the second parameter, it starts with
- that. Otherwise, it uses a blank dictionary.
- Either way, the dictionary is updated from the keywords.
-
- The dictionary is then reloaded with the values actually inserted in
- order to pick up values modified by rules, triggers, etc.
- """
- if table.endswith('*'): # hint for descendant tables can be ignored
- table = table[:-1].rstrip()
- if row is None:
- row = {}
- row.update(kw)
- if 'oid' in row:
- del row['oid'] # do not insert oid
- attnames = self.get_attnames(table)
- generated = self.get_generated(table)
- qoid = _oid_key(table) if 'oid' in attnames else None
- params = self.adapter.parameter_list()
- adapt = params.add
- col = self.escape_identifier
- name_list, value_list = [], []
- for n in attnames:
- if n in row and n not in generated:
- name_list.append(col(n))
- value_list.append(adapt(row[n], attnames[n]))
- if not name_list:
- raise _prg_error('No column found that can be inserted')
- names, values = ', '.join(name_list), ', '.join(value_list)
- ret = 'oid, *' if qoid else '*'
- t = self._escape_qualified_name(table)
- cmd = (f'INSERT INTO {t} ({names})' # noqa: S608
- f' VALUES ({values}) RETURNING {ret}')
- self._do_debug(cmd, params)
- query = self._valid_db.query(cmd, params)
- res = query.dictresult()
- if res: # this should always be true
- for n, value in res[0].items():
- if qoid and n == 'oid':
- n = qoid
- row[n] = value
- return row
-
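For example, continuing the hypothetical fruits table, keywords update the row dictionary and server-assigned values are reloaded into it:

    row = db.insert('fruits', name='banana')  # id comes from the serial column
    print(row['id'], row['name'])             # e.g. -> 2 banana
    row = db.insert('fruits', {'name': 'cherry'}, name='kiwi')  # keywords win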
- def update(self, table: str, row: dict[str, Any] | None = None, **kw: Any
- ) -> dict[str, Any]:
- """Update an existing row in a database table.
-
- Similar to insert, but updates an existing row. The update is based
- on the primary key of the table or the OID value as munged by get()
- or passed as keyword. The OID will take precedence if provided, so
- that it is possible to update the primary key itself.
-
- The dictionary is then modified to reflect any changes caused by the
- update due to triggers, rules, default values, etc.
- """
- if table.endswith('*'):
- table = table[:-1].rstrip() # need parent table name
- attnames = self.get_attnames(table)
- generated = self.get_generated(table)
- qoid = _oid_key(table) if 'oid' in attnames else None
- if row is None:
- row = {}
- elif 'oid' in row:
- del row['oid'] # only accept oid key from named args for safety
- row.update(kw)
- if qoid and qoid in row and 'oid' not in row:
- row['oid'] = row[qoid]
- if qoid and 'oid' in row: # try using the oid
- keynames: tuple[str, ...] = ('oid',)
- keyset = set(keynames)
- else: # try using the primary key
- try:
- keynames = self.pkeys(table)
- except KeyError as e: # the table has no primary key
- raise _prg_error(f'Table {table} has no primary key') from e
- keyset = set(keynames)
- # check whether all key columns have values
- if not keyset.issubset(row):
- raise KeyError('Missing value for primary key in row')
- params = self.adapter.parameter_list()
- adapt = params.add
- col = self.escape_identifier
- where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format(
- col(k), adapt(row[k], attnames[k])) for k in keynames)
- if 'oid' in row:
- if qoid:
- row[qoid] = row['oid']
- del row['oid']
- values_list = []
- for n in attnames:
- if n in row and n not in keyset and n not in generated:
- values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}')
- if not values_list:
- return row
- values = ', '.join(values_list)
- ret = 'oid, *' if qoid else '*'
- t = self._escape_qualified_name(table)
- cmd = (f'UPDATE {t} SET {values}' # noqa: S608
- f' WHERE {where} RETURNING {ret}')
- self._do_debug(cmd, params)
- query = self._valid_db.query(cmd, params)
- res = query.dictresult()
- if res: # may be empty when row does not exist
- for n, value in res[0].items():
- if qoid and n == 'oid':
- n = qoid
- row[n] = value
- return row
-
- def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any
- ) -> dict[str, Any]:
- """Insert a row into a database table with conflict resolution.
-
- This method inserts a row into a table, but instead of raising a
- ProgrammingError exception in case a row with the same primary key
- already exists, an update will be executed instead. This will be
- performed as a single atomic operation on the database, so race
- conditions can be avoided.
-
- Like the insert method, the first parameter is the name of the
- table and the second parameter can be used to pass the values to
- be inserted as a dictionary.
-
- Unlike the insert and update methods, keyword parameters are not
- used to modify the dictionary, but to specify which columns shall
- be updated in case of a conflict, and in which way:
-
- A value of False or None means the column shall not be updated,
- a value of True means the column shall be updated with the value
- that has been proposed for insertion, i.e. has been passed as value
- in the dictionary. Columns that are not specified by keywords but
- appear as keys in the dictionary are also updated, as if keywords
- had been passed with the value True.
-
- So if in the case of a conflict you want to update every column
- that has been passed in the dictionary row, you would call
- upsert(table, row). If you don't want to do anything in case
- of a conflict, i.e. leave the existing row as it is, call
- upsert(table, row, **dict.fromkeys(row)).
-
- If you need more fine-grained control of what gets updated, you can
- also pass strings in the keyword parameters. These strings will
- be used as SQL expressions for the update columns. In these
- expressions you can refer to the value that already exists in
- the table by prefixing the column name with "included.", and to
- the value that has been proposed for insertion by prefixing the
- column name with "excluded.".
-
- The dictionary is modified in any case to reflect the values in
- the database after the operation has completed.
-
- Note: The method uses the PostgreSQL "upsert" feature which is
- only available since PostgreSQL 9.5.
- """
- if table.endswith('*'): # hint for descendant tables can be ignored
- table = table[:-1].rstrip()
- if row is None:
- row = {}
- if 'oid' in row:
- del row['oid'] # do not insert oid
- if 'oid' in kw:
- del kw['oid'] # do not update oid
- attnames = self.get_attnames(table)
- generated = self.get_generated(table)
- qoid = _oid_key(table) if 'oid' in attnames else None
- params = self.adapter.parameter_list()
- adapt = params.add
- col = self.escape_identifier
- name_list, value_list = [], []
- for n in attnames:
- if n in row and n not in generated:
- name_list.append(col(n))
- value_list.append(adapt(row[n], attnames[n]))
- names, values = ', '.join(name_list), ', '.join(value_list)
- try:
- keynames = self.pkeys(table)
- except KeyError as e:
- raise _prg_error(f'Table {table} has no primary key') from e
- target = ', '.join(col(k) for k in keynames)
- update = []
- keyset = set(keynames)
- keyset.add('oid')
- for n in attnames:
- if n not in keyset and n not in generated:
- value = kw.get(n, n in row)
- if value:
- if not isinstance(value, str):
- value = f'excluded.{col(n)}'
- update.append(f'{col(n)} = {value}')
- if not values:
- return row
- do = 'update set ' + ', '.join(update) if update else 'nothing'
- ret = 'oid, *' if qoid else '*'
- t = self._escape_qualified_name(table)
- cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608
- f' VALUES ({values})'
- f' ON CONFLICT ({target}) DO {do} RETURNING {ret}')
- self._do_debug(cmd, params)
- query = self._valid_db.query(cmd, params)
- res = query.dictresult()
- if res: # may be empty with "do nothing"
- for n, value in res[0].items():
- if qoid and n == 'oid':
- n = qoid
- row[n] = value
- else:
- self.get(table, row)
- return row
-
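A hedged sketch of the conflict-resolution options (same hypothetical fruits table, with id as the primary key):

    row = db.upsert('fruits', {'id': 1, 'name': 'apricot'})  # update name on conflict
    row = db.upsert('fruits', {'id': 1, 'name': 'apple'},
                    name=False)  # on conflict, leave the stored name unchanged
    row = db.upsert('fruits', {'id': 1, 'name': 'apple'},
                    name="included.name || '/' || excluded.name")  # SQL expression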
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - if row is None: - row = {} - if 'oid' in row: - del row['oid'] # do not insert oid - if 'oid' in kw: - del kw['oid'] # do not update oid - attnames = self.get_attnames(table) - generated = self.get_generated(table) - qoid = _oid_key(table) if 'oid' in attnames else None - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - name_list, value_list = [], [] - for n in attnames: - if n in row and n not in generated: - name_list.append(col(n)) - value_list.append(adapt(row[n], attnames[n])) - names, values = ', '.join(name_list), ', '.join(value_list) - try: - keynames = self.pkeys(table) - except KeyError as e: - raise _prg_error(f'Table {table} has no primary key') from e - target = ', '.join(col(k) for k in keynames) - update = [] - keyset = set(keynames) - keyset.add('oid') - for n in attnames: - if n not in keyset and n not in generated: - value = kw.get(n, n in row) - if value: - if not isinstance(value, str): - value = f'excluded.{col(n)}' - update.append(f'{col(n)} = {value}') - if not values: - return row - do = 'update set ' + ', '.join(update) if update else 'nothing' - ret = 'oid, *' if qoid else '*' - t = self._escape_qualified_name(table) - cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 - f' VALUES ({values})' - f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if res: # may be empty with "do nothing" - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - else: - self.get(table, row) - return row - - def clear(self, table: str, row: dict[str, Any] | None = None - ) -> dict[str, Any]: - """Clear all the attributes to values determined by the types. - - Numeric types are set to 0, Booleans are set to false, and everything - else is set to the empty string. If the row argument is present, - it is used as the row dictionary and any entries matching attribute - names are cleared with everything else left unchanged. - """ - # At some point we will need a way to get defaults from a table. - if row is None: - row = {} # empty if argument is not present - attnames = self.get_attnames(table) - for n, t in attnames.items(): - if n == 'oid': - continue - t = t.simple - if t in DbTypes._num_types: - row[n] = 0 - elif t == 'bool': - row[n] = self._make_bool(False) - else: - row[n] = '' - return row - - def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any - ) -> int: - """Delete an existing row in a database table. - - This method deletes the row from a table. It deletes based on the - primary key of the table or the OID value as munged by get() or - passed as keyword. The OID will take precedence if provided. - - The return value is the number of deleted rows (i.e. 0 if the row - did not exist and 1 if the row was deleted). - - Note that if the row cannot be deleted because e.g. it is still - referenced by another table, this method raises a ProgrammingError. 
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if row is None: - row = {} - elif 'oid' in row: - del row['oid'] # only accept oid key from named args for safety - row.update(kw) - if qoid and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if qoid and 'oid' in row: # try using the oid - keynames: tuple[str, ...] = ('oid',) - else: # try using the primary key - try: - keynames = self.pkeys(table) - except KeyError as e: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') from e - # check whether all key columns have values - if not set(keynames).issubset(row): - raise KeyError('Missing value for primary key in row') - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keynames) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - t = self._escape_qualified_name(table) - cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 - self._do_debug(cmd, params) - res = self._valid_db.query(cmd, params) - return int(res) # type: ignore - - def truncate(self, table: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str], restart: bool = False, - cascade: bool = False, only: bool = False) -> Query: - """Empty a table or set of tables. - - This method quickly removes all rows from the given table or set - of tables. It has the same effect as an unqualified DELETE on each - table, but since it does not actually scan the tables it is faster. - Furthermore, it reclaims disk space immediately, rather than requiring - a subsequent VACUUM operation. This is most useful on large tables. - - If restart is set to True, sequences owned by columns of the truncated - table(s) are automatically restarted. If cascade is set to True, it - also truncates all tables that have foreign-key references to any of - the named tables. If the parameter 'only' is not set to True, all the - descendant tables (if any) will also be truncated. Optionally, a '*' - can be specified after the table name to explicitly indicate that - descendant tables are included. 
- """ - if isinstance(table, str): - table_only = {table: only} - table = [table] - elif isinstance(table, (list, tuple)): - if isinstance(only, (list, tuple)): - table_only = dict(zip(table, only)) - else: - table_only = dict.fromkeys(table, only) - elif isinstance(table, (set, frozenset)): - table_only = dict.fromkeys(table, only) - else: - raise TypeError('The table must be a string, list or set') - if not (restart is None or isinstance(restart, (bool, int))): - raise TypeError('Invalid type for the restart option') - if not (cascade is None or isinstance(cascade, (bool, int))): - raise TypeError('Invalid type for the cascade option') - tables = [] - for t in table: - u = table_only.get(t) - if not (u is None or isinstance(u, (bool, int))): - raise TypeError('Invalid type for the only option') - if t.endswith('*'): - if u: - raise ValueError( - 'Contradictory table name and only options') - t = t[:-1].rstrip() - t = self._escape_qualified_name(t) - if u: - t = f'ONLY {t}' - tables.append(t) - cmd_parts = ['TRUNCATE', ', '.join(tables)] - if restart: - cmd_parts.append('RESTART IDENTITY') - if cascade: - cmd_parts.append('CASCADE') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - return self._valid_db.query(cmd) - - def get_as_list( - self, table: str, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | bool | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> list: - """Get a table as a list. - - This gets a convenient representation of the table as a list - of named tuples in Python. You only need to pass the name of - the table (or any other SQL expression returning rows). Note that - by default this will return the full content of the table which - can be huge and overflow your memory. However, you can control - the amount of data returned using the other optional parameters. - - The parameter 'what' can restrict the query to only return a - subset of the table columns. It can be a string, list or a tuple. - - The parameter 'where' can restrict the query to only return a - subset of the table rows. It can be a string, list or a tuple - of SQL expressions that all need to be fulfilled. - - The parameter 'order' specifies the ordering of the rows. It can - also be a string, list or a tuple. If no ordering is specified, - the result will be ordered by the primary key(s) or all columns if - no primary key exists. You can set 'order' to False if you don't - care about the ordering. The parameters 'limit' and 'offset' can be - integers specifying the maximum number of rows returned and a number - of rows skipped over. - - If you set the 'scalar' option to True, then instead of the - named tuples you will get the first items of these tuples. - This is useful if the result has only one column anyway. 
- """ - if not table: - raise TypeError('The table name is missing') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - cmd_parts = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - cmd_parts.extend(['WHERE', where]) - if order is None or order is True: - try: - order = self.pkeys(table) - except (KeyError, ProgrammingError): - with suppress(KeyError, ProgrammingError): - order = list(self.get_attnames(table)) - if order and not isinstance(order, bool): - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - cmd_parts.extend(['ORDER BY', order]) - if limit: - cmd_parts.append(f'LIMIT {limit}') - if offset: - cmd_parts.append(f'OFFSET {offset}') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - query = self._valid_db.query(cmd) - res = query.namedresult() - if res and scalar: - res = [row[0] for row in res] - return res - - def get_as_dict( - self, table: str, - keyname: str | list[str] | tuple[str, ...] | None = None, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | bool | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> dict: - """Get a table as a dictionary. - - This method is similar to get_as_list(), but returns the table - as a Python dict instead of a Python list, which can be even - more convenient. The primary key column(s) of the table will - be used as the keys of the dictionary, while the other column(s) - will be the corresponding values. The keys will be named tuples - if the table has a composite primary key. The rows will be also - named tuples unless the 'scalar' option has been set to True. - With the optional parameter 'keyname' you can specify an alternative - set of columns to be used as the keys of the dictionary. It must - be set as a string, list or a tuple. - - The dictionary will be ordered using the order specified with the - 'order' parameter or the key column(s) if not specified. You can - set 'order' to False if you don't care about the ordering. - """ - if not table: - raise TypeError('The table name is missing') - if not keyname: - try: - keyname = self.pkeys(table) - except (KeyError, ProgrammingError) as e: - raise _prg_error(f'Table {table} has no primary key') from e - if isinstance(keyname, str): - keynames: list[str] | tuple[str, ...] 
= (keyname,) - elif isinstance(keyname, (list, tuple)): - keynames = keyname - else: - raise KeyError('The keyname must be a string, list or tuple') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - cmd_parts = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - cmd_parts.extend(['WHERE', where]) - if order is None or order is True: - order = keyname - if order and not isinstance(order, bool): - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - cmd_parts.extend(['ORDER BY', order]) - if limit: - cmd_parts.append(f'LIMIT {limit}') - if offset: - cmd_parts.append(f'OFFSET {offset}') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - query = self._valid_db.query(cmd) - res = query.getresult() - if not res: - return {} - keyset = set(keynames) - fields = query.listfields() - if not keyset.issubset(fields): - raise KeyError('Missing keyname in row') - key_index: list[int] = [] - row_index: list[int] = [] - for i, f in enumerate(fields): - (key_index if f in keyset else row_index).append(i) - key_tuple = len(key_index) > 1 - get_key = itemgetter(*key_index) - keys = map(get_key, res) - if scalar: - row_index = row_index[:1] - row_is_tuple = False - else: - row_is_tuple = len(row_index) > 1 - if scalar or row_is_tuple: - get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore - *row_index) - else: - frst_index = row_index[0] - - def get_row(row : tuple) -> tuple: - return row[frst_index], # tuple with one item - - row_is_tuple = True - rows = map(get_row, res) - if key_tuple or row_is_tuple: - if key_tuple: - keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore - if row_is_tuple: - fields = tuple(f for f in fields if f not in keyset) - rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore - # noinspection PyArgumentList - return dict(zip(keys, rows)) - - def notification_handler(self, event: str, callback: Callable, - arg_dict: dict | None = None, - timeout: int | float | None = None, - stop_event: str | None = None - ) -> NotificationHandler: - """Get notification handler that will run the given callback.""" - return NotificationHandler(self, event, callback, - arg_dict, timeout, stop_event) - - -# if run as script, print some information - -if __name__ == '__main__': - print('PyGreSQL version', version) - print() - print(__doc__) +init_core() diff --git a/pg/adapt.py b/pg/adapt.py new file mode 100644 index 00000000..fd4705ae --- /dev/null +++ b/pg/adapt.py @@ -0,0 +1,680 @@ +"""Adaption of parameters.""" + +from __future__ import annotations + +import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from json import dumps as jsonencode +from math import isinf, isnan +from re import compile as regex +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Sequence +from uuid import UUID + +from .attrs import AttrDict +from .cast import Typecasts +from .core import InterfaceError, ProgrammingError +from .helpers import quote_if_unqualified + +if TYPE_CHECKING: + from .db import DB + +__all__ = [ + 'Adapter', 'Bytea', 'DbType', 'DbTypes', + 'Hstore', 'Literal', 'Json', 'UUID' +] + + +class Bytea(bytes): + """Wrapper class for marking Bytea values.""" + + +class Hstore(dict): + """Wrapper class for marking hstore values.""" + + _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') + + @classmethod + def 
_quote(cls, s: Any) -> str: + if s is None: + return 'NULL' + if not isinstance(s, str): + s = str(s) + if not s: + return '""' + s = s.replace('"', '\\"') + if cls._re_quote.search(s): + s = f'"{s}"' + return s + + def __str__(self) -> str: + """Create a printable representation of the hstore value.""" + q = self._quote + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) + + +class Json: + """Wrapper class for marking Json values.""" + + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: + """Initialize the JSON object.""" + self.obj = obj + self.encode = encode or jsonencode + + def __str__(self) -> str: + """Create a printable representation of the JSON object.""" + obj = self.obj + if isinstance(obj, str): + return obj + return self.encode(obj) + + +class Literal(str): + """Wrapper class for marking literal SQL values.""" + + + +class _SimpleTypes(dict): + """Dictionary mapping pg_type names to simple type names. + + The corresponding Python types and simple names are also mapped. + """ + + _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ + 'bool': [bool], + 'bytea': [Bytea], + 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', + 'abstime', 'reltime', # these are very old + 'datetime', 'timedelta', # these do not really exist + date, time, datetime, timedelta], + 'float': ['float4', 'float8', float], + 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], + 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], + 'num': ['numeric', Decimal], 'money': [], + 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] + }) + + # noinspection PyMissingConstructor + def __init__(self) -> None: + """Initialize type mapping.""" + for typ, keys in self._type_aliases.items(): + keys = [typ, *keys] + for key in keys: + self[key] = typ + if isinstance(key, str): + self[f'_{key}'] = f'{typ}[]' + elif not isinstance(key, tuple): + self[List[key]] = f'{typ}[]' # type: ignore + + @staticmethod + def __missing__(key: str) -> str: + """Unmapped types are interpreted as text.""" + return 'text' + + def get_type_dict(self) -> dict[type, str]: + """Get a plain dictionary of only the types.""" + return {key: typ for key, typ in self.items() + if not isinstance(key, (str, tuple))} + + +_simpletypes = _SimpleTypes() +_simple_type_dict = _simpletypes.get_type_dict() + + +class _ParameterList(list): + """Helper class for building typed parameter lists.""" + + adapt: Callable + + def add(self, value: Any, typ:Any = None) -> str: + """Typecast value with known database type and build parameter list. + + If this is a literal value, it will be returned as is. Otherwise, a + placeholder will be returned and the parameter list will be augmented. + """ + # noinspection PyUnresolvedReferences + value = self.adapt(value, typ) + if isinstance(value, Literal): + return value + self.append(value) + return f'${len(self)}' + + + +class DbType(str): + """Class augmenting the simple type name with additional info. + + The following additional information is provided: + + oid: the PostgreSQL type OID + pgtype: the internal PostgreSQL data type name + regtype: the registered PostgreSQL data type name + simple: the more coarse-grained PyGreSQL type name + typlen: the internal size, negative if variable + typtype: b = base type, c = composite type etc. + category: A = Array, b = Boolean, C = Composite etc. 
+ delim: delimiter for array types + relid: corresponding table for composite types + attnames: attributes for composite types + """ + + oid: int + pgtype: str + regtype: str + simple: str + typlen: int + typtype: str + category: str + delim: str + relid: int + + _get_attnames: Callable[[DbType], AttrDict] + + @property + def attnames(self) -> AttrDict: + """Get names and types of the fields of a composite type.""" + # noinspection PyUnresolvedReferences + return self._get_attnames(self) + + +class DbTypes(dict): + """Cache for PostgreSQL data types. + + This cache maps type OIDs and names to DbType objects containing + information on the associated database type. + """ + + _num_types = frozenset('int float num money int2 int4 int8' + ' float4 float8 numeric money'.split()) + + def __init__(self, db: DB) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._db = weakref.proxy(db) + self._regtypes = False + self._typecasts = Typecasts() + self._typecasts.get_attnames = self.get_attnames # type: ignore + self._typecasts.connection = self._db.db + self._query_pg_type = ( + "SELECT oid, typname, oid::pg_catalog.regtype," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type" + " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") + + def add(self, oid: int, pgtype: str, regtype: str, + typlen: int, typtype: str, category: str, delim: str, relid: int + ) -> DbType: + """Create a PostgreSQL type name with additional info.""" + if oid in self: + return self[oid] + simple = 'record' if relid else _simpletypes[pgtype] + typ = DbType(regtype if self._regtypes else simple) + typ.oid = oid + typ.simple = simple + typ.pgtype = pgtype + typ.regtype = regtype + typ.typlen = typlen + typ.typtype = typtype + typ.category = category + typ.delim = delim + typ.relid = relid + typ._get_attnames = self.get_attnames # type: ignore + return typ + + def __missing__(self, key: int | str) -> DbType: + """Get the type info from the database if it is not cached.""" + try: + cmd = self._query_pg_type.format(quote_if_unqualified('$1', key)) + res = self._db.query(cmd, (key,)).getresult() + except ProgrammingError: + res = None + if not res: + raise KeyError(f'Type {key} could not be found') + res = res[0] + typ = self.add(*res) + self[typ.oid] = self[typ.pgtype] = typ + return typ + + def get(self, key: int | str, # type: ignore + default: DbType | None = None) -> DbType | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_attnames(self, typ: Any) -> AttrDict | None: + """Get names and types of the fields of a composite type.""" + if not isinstance(typ, DbType): + typ = self.get(typ) + if not typ: + return None + if not typ.relid: + return None + return self._db.get_attnames(typ.relid, with_oid=False) + + def get_typecast(self, typ: Any) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts.get(typ) + + def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is 
necessary + return None + if not isinstance(typ, DbType): + db_type = self.get(typ) + if db_type: + typ = db_type.pgtype + cast = self.get_typecast(typ) if typ else None + if not cast or cast is str: + # no typecast is necessary + return value + return cast(value) + + +class Adapter: + """Class providing methods for adapting parameters to the database.""" + + _bool_true_values = frozenset('t true 1 y yes on'.split()) + + _date_literals = frozenset( + 'current_date current_time' + ' current_timestamp localtime localtimestamp'.split()) + + _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') + _re_record_quote = regex(r'[(,"\\]') + _re_array_escape = _re_record_escape = regex(r'(["\\])') + + def __init__(self, db: DB): + """Initialize the adapter object with the given connection.""" + self.db = weakref.proxy(db) + + @classmethod + def _adapt_bool(cls, v: Any) -> str | None: + """Adapt a boolean parameter.""" + if isinstance(v, str): + if not v: + return None + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_date(cls, v: Any) -> Any: + """Adapt a date parameter.""" + if not v: + return None + if isinstance(v, str) and v.lower() in cls._date_literals: + return Literal(v) + return v + + @staticmethod + def _adapt_num(v: Any) -> Any: + """Adapt a numeric parameter.""" + if not v and v != 0: + return None + return v + + _adapt_int = _adapt_float = _adapt_money = _adapt_num + + def _adapt_bytea(self, v: Any) -> str: + """Adapt a bytea parameter.""" + return self.db.escape_bytea(v) + + def _adapt_json(self, v: Any) -> str | None: + """Adapt a json parameter.""" + if not v: + return None + if isinstance(v, str): + return v + if isinstance(v, Json): + return str(v) + return self.db.encode_json(v) + + def _adapt_hstore(self, v: Any) -> str | None: + """Adapt a hstore parameter.""" + if not v: + return None + if isinstance(v, str): + return v + if isinstance(v, Hstore): + return str(v) + if isinstance(v, dict): + return str(Hstore(v)) + raise TypeError(f'Hstore parameter {v} has wrong type') + + def _adapt_uuid(self, v: Any) -> str | None: + """Adapt a UUID parameter.""" + if not v: + return None + if isinstance(v, str): + return v + return str(v) + + @classmethod + def _adapt_text_array(cls, v: Any) -> str: + """Adapt a text type array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_text_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if not v: + return '""' + v = str(v) + if cls._re_array_quote.search(v): + v = cls._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + _adapt_date_array = _adapt_text_array + + @classmethod + def _adapt_bool_array(cls, v: Any) -> str: + """Adapt a boolean array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_bool_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if isinstance(v, str): + if not v: + return 'null' + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_num_array(cls, v: Any) -> str: + """Adapt a numeric array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_num_array + v = '{' + ','.join(adapt(v) for v in v) + '}' + if not v and v != 0: + return 'null' + return str(v) + + _adapt_int_array = _adapt_float_array = _adapt_money_array = \ + _adapt_num_array + + def _adapt_bytea_array(self, v: Any) -> bytes: + """Adapt a bytea array parameter.""" + if isinstance(v, list): + return b'{' + b','.join( + self._adapt_bytea_array(v) for v in v) 
+ b'}' + if v is None: + return b'null' + return self.db.escape_bytea(v).replace(b'\\', b'\\\\') + + def _adapt_json_array(self, v: Any) -> str: + """Adapt a json array parameter.""" + if isinstance(v, list): + adapt = self._adapt_json_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if not v: + return 'null' + if not isinstance(v, str): + v = self.db.encode_json(v) + if self._re_array_quote.search(v): + v = self._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + def _adapt_record(self, v: Any, typ: Any) -> str: + """Adapt a record parameter with given type.""" + typ = self.get_attnames(typ).values() + if len(typ) != len(v): + raise TypeError(f'Record parameter {v} has wrong size') + adapt = self.adapt + value = [] + for v, t in zip(v, typ): # noqa: B020 + v = adapt(v, t) + if v is None: + v = '' + else: + if isinstance(v, bytes): + v = v.decode('ascii') + elif not isinstance(v, str): + v = str(v) + if v: + if self._re_record_quote.search(v): + v = self._re_record_escape.sub(r'\\\1', v) + v = f'"{v}"' + else: + v = '""' + value.append(v) + v = ','.join(value) + return f'({v})' + + def adapt(self, value: Any, typ: Any = None) -> str: + """Adapt a value with known database type.""" + if value is not None and not isinstance(value, Literal): + if typ: + simple = self.get_simple_name(typ) + else: + typ = simple = self.guess_simple_type(value) or 'text' + pg_str = getattr(value, '__pg_str__', None) + if pg_str: + value = pg_str(typ) + if simple == 'text': + pass + elif simple == 'record': + if isinstance(value, tuple): + value = self._adapt_record(value, typ) + elif simple.endswith('[]'): + if isinstance(value, list): + adapt = getattr(self, f'_adapt_{simple[:-2]}_array') + value = adapt(value) + else: + adapt = getattr(self, f'_adapt_{simple}') + value = adapt(value) + return value + + @staticmethod + def simple_type(name: str) -> DbType: + """Create a simple database type with given attribute names.""" + typ = DbType(name) + typ.simple = name + return typ + + @staticmethod + def get_simple_name(typ: Any) -> str: + """Get the simple name of a database type.""" + if isinstance(typ, DbType): + # noinspection PyUnresolvedReferences + return typ.simple + return _simpletypes[typ] + + @staticmethod + def get_attnames(typ: Any) -> dict[str, dict[str, str]]: + """Get the attribute names of a composite database type.""" + if isinstance(typ, DbType): + return typ.attnames + return {} + + @classmethod + def guess_simple_type(cls, value: Any) -> str | None: + """Try to guess which database type the given value has.""" + # optimize for most frequent types + try: + return _simple_type_dict[type(value)] + except KeyError: + pass + if isinstance(value, (bytes, str)): + return 'text' + if isinstance(value, bool): + return 'bool' + if isinstance(value, int): + return 'int' + if isinstance(value, float): + return 'float' + if isinstance(value, Decimal): + return 'num' + if isinstance(value, (date, time, datetime, timedelta)): + return 'date' + if isinstance(value, Bytea): + return 'bytea' + if isinstance(value, Json): + return 'json' + if isinstance(value, Hstore): + return 'hstore' + if isinstance(value, UUID): + return 'uuid' + if isinstance(value, list): + return (cls.guess_simple_base_type(value) or 'text') + '[]' + if isinstance(value, tuple): + simple_type = cls.simple_type + guess = cls.guess_simple_type + + # noinspection PyUnusedLocal + def get_attnames(self: DbType) -> AttrDict: + return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) + for n, v in enumerate(value)) + + typ 
= simple_type('record') + typ._get_attnames = get_attnames + return typ + return None + + @classmethod + def guess_simple_base_type(cls, value: Any) -> str | None: + """Try to guess the base type of a given array.""" + for v in value: + if isinstance(v, list): + typ = cls.guess_simple_base_type(v) + else: + typ = cls.guess_simple_type(v) + if typ: + return typ + return None + + def adapt_inline(self, value: Any, nested: bool=False) -> Any: + """Adapt a value that is put into the SQL and needs to be quoted.""" + if value is None: + return 'NULL' + if isinstance(value, Literal): + return value + if isinstance(value, Bytea): + value = self.db.escape_bytea(value).decode('ascii') + elif isinstance(value, (datetime, date, time, timedelta)): + value = str(value) + if isinstance(value, (bytes, str)): + value = self.db.escape_string(value) + return f"'{value}'" + if isinstance(value, bool): + return 'true' if value else 'false' + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal)): + return value + if isinstance(value, list): + q = self.adapt_inline + s = '[{}]' if nested else 'ARRAY[{}]' + return s.format(','.join(str(q(v, nested=True)) for v in value)) + if isinstance(value, tuple): + q = self.adapt_inline + return '({})'.format(','.join(str(q(v)) for v in value)) + if isinstance(value, Json): + value = self.db.escape_string(str(value)) + return f"'{value}'::json" + if isinstance(value, Hstore): + value = self.db.escape_string(str(value)) + return f"'{value}'::hstore" + pg_repr = getattr(value, '__pg_repr__', None) + if not pg_repr: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') + value = pg_repr() + if isinstance(value, (tuple, list)): + value = self.adapt_inline(value) + return value + + def parameter_list(self) -> _ParameterList: + """Return a parameter list for parameters with known database types. + + The list has an add(value, typ) method that will build up the + list and return either the literal value or a placeholder. + """ + params = _ParameterList() + params.adapt = self.adapt + return params + + def format_query(self, command: str, + values: list | tuple | dict | None = None, + types: list | tuple | dict | None = None, + inline: bool=False + ) -> tuple[str, _ParameterList]: + """Format a database query using the given values and types. + + The optional types describe the values and must be passed as a list, + tuple or string (that will be split on whitespace) when values are + passed as a list or tuple, or as a dict if values are passed as a dict. + + If inline is set to True, then parameters will be passed inline + together with the query string. 
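A brief sketch of what this adaptation step produces; the table and values are illustrative, and the query_formatted() method of DB shown later simply wraps this call:

    from pg import DB

    db = DB(dbname='testdb')  # hypothetical connection
    cmd, params = db.adapter.format_query(
        'SELECT * FROM weather WHERE city = %s AND temp_lo > %s',
        ('Hayward', 45))
    # cmd == 'SELECT * FROM weather WHERE city = $1 AND temp_lo > $2'
    # params == ['Hayward', 45]; both can be passed on to db.query()
    db.query(cmd, params)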
+ """ + params = self.parameter_list() + if not values: + return command, params + if inline and types: + raise ValueError('Typed parameters must be sent separately') + if isinstance(values, (list, tuple)): + if inline: + adapt = self.adapt_inline + seq_literals = [adapt(value) for value in values] + else: + add = params.add + if types: + if isinstance(types, str): + types = types.split() + if (not isinstance(types, (list, tuple)) + or len(types) != len(values)): + raise TypeError('The values and types do not match') + seq_literals = [add(value, typ) + for value, typ in zip(values, types)] + else: + seq_literals = [add(value) for value in values] + command %= tuple(seq_literals) + elif isinstance(values, dict): + # we want to allow extra keys in the dictionary, + # so we first must find the values actually used in the command + used_values = {} + map_literals = dict.fromkeys(values, '') + for key in values: + del map_literals[key] + try: + command % map_literals + except KeyError: + used_values[key] = values[key] # pyright: ignore + map_literals[key] = '' + if inline: + adapt = self.adapt_inline + map_literals = {key: adapt(value) + for key, value in used_values.items()} + else: + add = params.add + if types: + if not isinstance(types, dict): + raise TypeError('The values and types do not match') + map_literals = {key: add(used_values[key], types.get(key)) + for key in sorted(used_values)} + else: + map_literals = {key: add(used_values[key]) + for key in sorted(used_values)} + command %= map_literals + else: + raise TypeError('The values must be passed as tuple, list or dict') + return command, params diff --git a/pg/attrs.py b/pg/attrs.py new file mode 100644 index 00000000..7a5e6c41 --- /dev/null +++ b/pg/attrs.py @@ -0,0 +1,35 @@ +"""Helpers for memorizing attributes.""" + +from typing import Any + +__all__ = ['AttrDict'] + + +class AttrDict(dict): + """Simple read-only ordered dictionary for storing attribute names.""" + + def __init__(self, *args: Any, **kw: Any) -> None: + """Initialize the dictionary.""" + self._read_only = False + super().__init__(*args, **kw) + self._read_only = True + error = self._read_only_error + self.clear = self.update = error # type: ignore + self.pop = self.setdefault = self.popitem = error # type: ignore + + def __setitem__(self, key: str, value: Any) -> None: + """Set a value.""" + if self._read_only: + self._read_only_error() + super().__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + """Delete a value.""" + if self._read_only: + self._read_only_error() + super().__delitem__(key) + + @staticmethod + def _read_only_error(*_args: Any, **_kw: Any) -> Any: + """Raise error for write operations.""" + raise TypeError('This object is read-only') diff --git a/pg/cast.py b/pg/cast.py new file mode 100644 index 00000000..ad1758be --- /dev/null +++ b/pg/cast.py @@ -0,0 +1,436 @@ +"""Typecasting mechanisms.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, timedelta +from functools import partial +from inspect import signature +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID + +from .attrs import AttrDict +from .core import ( + Connection, + cast_array, + cast_hstore, + cast_record, + get_bool, + get_decimal, + get_decimal_point, + get_jsondecode, + unescape_bytea, +) +from .tz import timezone_as_offset + +__all__ = [ + 'cast_bool', 'cast_json', 'cast_num', 'cast_money', 'cast_int2vector', + 'cast_date', 'cast_time', 
'cast_timetz', 'cast_interval', + 'cast_timestamp','cast_timestamptz', + 'Typecasts', 'get_typecast', 'set_typecast' +] + +def get_args(func: Callable) -> list: + """Get the arguments of a function.""" + return list(signature(func).parameters) + + +def cast_bool(value: str) -> Any: + """Cast a boolean value.""" + if not get_bool(): + return value + return value[0] == 't' + + +def cast_json(value: str) -> Any: + """Cast a JSON value.""" + cast = get_jsondecode() + if not cast: + return value + return cast(value) + + +def cast_num(value: str) -> Any: + """Cast a numeric value.""" + return (get_decimal() or float)(value) + + +def cast_money(value: str) -> Any: + """Cast a money value.""" + point = get_decimal_point() + if not point: + return value + if point != '.': + value = value.replace(point, '.') + value = value.replace('(', '-') + value = ''.join(c for c in value if c.isdigit() or c in '.-') + return (get_decimal() or float)(value) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, connection: Connection) -> Any: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. + if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = connection.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> Any: + """Cast a time value.""" + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, format).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> Any: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, connection: Connection) -> Any: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, connection: Connection) -> Any: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if 
len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. The + passed string will never be None since NULL values are already + handled before the cast function is called. + + Note that the basic types are already handled by the C extension. + They only need to be handled here as record or array components. 
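To make this concrete, a custom cast for a type without a default cast could be registered as sketched below; the 'point' type with its '(x,y)' text output serves only as an example:

    import pg

    def cast_point(value):
        # PostgreSQL renders a point value as '(x,y)'
        x, y = value[1:-1].split(',')
        return float(x), float(y)

    pg.set_typecast('point', cast_point)  # register for new connections

Since connections cache cast functions, an already open DB instance will only pick this up after calling db.dbtypes.reset_typecast().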
+ """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, + 'float4': float, 'float8': float, + 'numeric': cast_num, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': UUID, + 'anyarray': cast_array, 'record': cast_record} # pyright: ignore + + connection: Connection | None = None # set in connection specific instance + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. + """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast: Callable | None = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + attnames = self.get_attnames(typ) + if attnames: + casts = [self[v.pgtype] for v in attnames.values()] + cast = self.create_record_cast(typ, attnames, casts) + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'connection' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.connection or not self._needs_connection(cast): + return cast + return partial(cast, connection=self.connection) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. 
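A minimal sketch of setting and resetting a cast on a Typecasts instance directly (the lambda is purely illustrative):

    from pg.cast import Typecasts

    casts = Typecasts()
    casts.set('point', lambda v: tuple(map(float, v[1:-1].split(','))))
    casts['point']('(1.5,2.5)')  # returns (1.5, 2.5)
    casts.reset('point')         # back to the default (here: no cast)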
+        """
+        if typ is None:
+            self.clear()
+        else:
+            if isinstance(typ, str):
+                typ = [typ]
+            for t in typ:
+                self.pop(t, None)
+
+    @classmethod
+    def get_default(cls, typ: str) -> Any:
+        """Get the default typecast function for the given database type."""
+        return cls.defaults.get(typ)
+
+    @classmethod
+    def set_default(cls, typ: str | Sequence[str],
+                    cast: Callable | None) -> None:
+        """Set a default typecast function for the given database type(s)."""
+        if isinstance(typ, str):
+            typ = [typ]
+        defaults = cls.defaults
+        if cast is None:
+            for t in typ:
+                defaults.pop(t, None)
+                defaults.pop(f'_{t}', None)
+        else:
+            if not callable(cast):
+                raise TypeError("Cast parameter must be callable")
+            for t in typ:
+                defaults[t] = cast
+                defaults.pop(f'_{t}', None)
+
+    # noinspection PyMethodMayBeStatic,PyUnusedLocal
+    def get_attnames(self, typ: Any) -> AttrDict:
+        """Return the fields for the given record type.
+
+        This method will be replaced with the get_attnames() method of DbTypes.
+        """
+        return AttrDict()
+
+    # noinspection PyMethodMayBeStatic
+    def dateformat(self) -> str:
+        """Return the current date format.
+
+        This method will be replaced with the dateformat() method of DbTypes.
+        """
+        return '%Y-%m-%d'
+
+    def create_array_cast(self, basecast: Callable) -> Callable:
+        """Create an array typecast for the given base cast."""
+        cast_array = self['anyarray']
+
+        def cast(v: Any) -> list:
+            return cast_array(v, basecast)
+        return cast
+
+    def create_record_cast(self, name: str, fields: AttrDict,
+                           casts: list[Callable]) -> Callable:
+        """Create a named record typecast for the given fields and casts."""
+        cast_record = self['record']
+        record = namedtuple(name, fields)  # type: ignore
+
+        def cast(v: Any) -> record:
+            # noinspection PyArgumentList
+            return record(*cast_record(v, casts))
+        return cast
+
+
+def get_typecast(typ: str) -> Callable | None:
+    """Get the global typecast function for the given database type."""
+    return Typecasts.get_default(typ)
+
+
+def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None:
+    """Set a global typecast function for the given database type(s).
+
+    Note that connections cache cast functions. To be sure a global change
+    is picked up by a running connection, call db.dbtypes.reset_typecast().
+    """
+    Typecasts.set_default(typ, cast)
diff --git a/pg/core.py b/pg/core.py
new file mode 100644
index 00000000..3eb8f745
--- /dev/null
+++ b/pg/core.py
@@ -0,0 +1,135 @@
+"""Core functionality from extension module."""
+
+try:
+    from ._pg import version
+except ImportError as e:  # noqa: F841
+    import os
+    libpq = 'libpq.'
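+    # The import of the C extension failed; figure out a helpful error
+    # message. On Windows, a common cause is that the directory holding
+    # libpq.dll is on PATH but is no longer searched for DLL dependencies
+    # since Python 3.8, so the import is retried below with each matching
+    # PATH entry registered via os.add_dll_directory().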
+ if os.name == 'nt': + libpq += 'dll' + import sys + paths = [path for path in os.environ["PATH"].split(os.pathsep) + if os.path.exists(os.path.join(path, libpq))] + if sys.version_info >= (3, 8): + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore + for path in paths: + with add_dll_dir(os.path.abspath(path)): + try: + from ._pg import version + except ImportError: + pass + else: + del version + e = None # type: ignore + break + if paths: + libpq = 'compatible ' + libpq + else: + libpq += 'so' + if e: + raise ImportError( + "Cannot import shared library for PyGreSQL,\n" + f"probably because no {libpq} is installed.\n{e}") from e +else: + del version + +# import objects from extension module +from ._pg import ( + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + Connection, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Query, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) + +__all__ = [ + 'Error', 'Warning', + 'DataError', 'DatabaseError', + 'IntegrityError', 'InterfaceError', 'InternalError', + 'InvalidResultError', 'MultipleResultsError', + 'NoResultError', 'NotSupportedError', + 'OperationalError', 'ProgrammingError', + 'Connection', 'Query', + 'INV_READ', 'INV_WRITE', + 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', + 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', + 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', + 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', + 'TRANS_INTRANS', 'TRANS_UNKNOWN', + 'cast_array', 'cast_hstore', 'cast_record', + 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', + 'get_array', 'get_bool', 'get_bytea_escaped', + 'get_datestyle', 'get_decimal', 'get_decimal_point', + 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', + 'get_jsondecode', 'get_pqlib_version', + 'set_array', 'set_bool', 'set_bytea_escaped', + 'set_datestyle', 'set_decimal', 'set_decimal_point', + 'set_defbase', 'set_defhost', 'set_defopt', + 'set_defpasswd', 'set_defport', 'set_defuser', + 'set_jsondecode', 'set_query_helpers', + 'version', +] diff --git a/pg/db.py b/pg/db.py new file mode 100644 index 00000000..ce7915f8 --- /dev/null +++ b/pg/db.py @@ -0,0 +1,1332 @@ +"""Connection wrapper.""" + +from __future__ import annotations + +from contextlib import suppress +from json import dumps as jsonencode +from json import loads as jsondecode +from operator import itemgetter +from typing import Any, Callable, Iterator, Sequence + +from . 
import Connection, connect
+from .adapt import Adapter, DbTypes
+from .attrs import AttrDict
+from .core import (
+    InternalError,
+    ProgrammingError,
+    Query,
+    get_bool,
+    get_jsondecode,
+    unescape_bytea,
+)
+from .error import db_error, int_error, prg_error
+from .helpers import namediter, oid_key, quote_if_unqualified
+from .notify import NotificationHandler
+
+__all__ = ['DB']
+
+# The actual PostgreSQL database connection interface:
+
+class DB:
+    """Wrapper class for the _pg connection type."""
+
+    db: Connection | None = None  # invalid fallback for underlying connection
+    _db_args: Any  # either the connect args or the underlying connection
+
+    def __init__(self, *args: Any, **kw: Any) -> None:
+        """Create a new connection.
+
+        You can pass either the connection parameters or an existing
+        _pg or pgdb connection. This allows you to use the methods
+        of the classic pg interface with a DB-API 2 pgdb connection.
+        """
+        if not args and len(kw) == 1:
+            db = kw.get('db')
+        elif not kw and len(args) == 1:
+            db = args[0]
+        else:
+            db = None
+        if db:
+            if isinstance(db, DB):
+                db = db.db
+            else:
+                with suppress(AttributeError):
+                    # noinspection PyUnresolvedReferences
+                    db = db._cnx
+        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
+            db = connect(*args, **kw)
+            self._db_args = args, kw
+            self._closeable = True
+        else:
+            self._db_args = db
+            self._closeable = False
+        self.db = db
+        self.dbname = db.db
+        self._regtypes = False
+        self._attnames: dict[str, AttrDict] = {}
+        self._generated: dict[str, frozenset[str]] = {}
+        self._pkeys: dict[str, str | tuple[str, ...]] = {}
+        self._privileges: dict[tuple[str, str], bool] = {}
+        self.adapter = Adapter(self)
+        self.dbtypes = DbTypes(self)
+        self._query_attnames = (
+            "SELECT a.attname,"
+            " t.oid, t.typname, t.oid::pg_catalog.regtype,"
+            " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid"
+            " FROM pg_catalog.pg_attribute a"
+            " JOIN pg_catalog.pg_type t"
+            " ON t.oid OPERATOR(pg_catalog.=) a.atttypid"
+            " WHERE a.attrelid OPERATOR(pg_catalog.=)"
+            " {}::pg_catalog.regclass"
+            " AND {} AND NOT a.attisdropped ORDER BY a.attnum")
+        if db.server_version < 120000:
+            self._query_generated = (
+                "a.attidentity OPERATOR(pg_catalog.=) 'a'"
+            )
+        else:
+            self._query_generated = (
+                "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR"
+                " a.attgenerated OPERATOR(pg_catalog.!=) '')"
+            )
+        db.set_cast_hook(self.dbtypes.typecast)
+        # For debugging scripts, self.debug can be set
+        # * to a string format specification (e.g. in CGI set to "%s<BR>"),
+        # * to a file object to write debug statements or
+        # * to a callable object which takes a string argument
+        # * to any other true value to just print debug statements
+        self.debug: Any = None
+
+    def __getattr__(self, name: str) -> Any:
+        """Get the specified attribute of the connection."""
+        # All undefined members are the same as in the underlying connection:
+        if self.db:
+            return getattr(self.db, name)
+        else:
+            raise int_error('Connection is not valid')
+
+    def __dir__(self) -> list[str]:
+        """List all attributes of the connection."""
+        # Custom dir function including the attributes of the connection:
+        attrs = set(self.__class__.__dict__)
+        attrs.update(self.__dict__)
+        attrs.update(dir(self.db))
+        return sorted(attrs)
+
+    # Context manager methods
+
+    def __enter__(self) -> DB:
+        """Enter the runtime context. This will start a transaction."""
+        self.begin()
+        return self
+
+    def __exit__(self, et: type[BaseException] | None,
+                 ev: BaseException | None, tb: Any) -> None:
+        """Exit the runtime context. This will end the transaction."""
+        if et is None and ev is None and tb is None:
+            self.commit()
+        else:
+            self.rollback()
+
+    def __del__(self) -> None:
+        """Delete the connection."""
+        try:
+            db = self.db
+        except AttributeError:
+            db = None
+        if db:
+            with suppress(TypeError):  # when already closed
+                db.set_cast_hook(None)
+            if self._closeable:
+                with suppress(InternalError):  # when already closed
+                    db.close()
+
+    # Auxiliary methods
+
+    def _do_debug(self, *args: Any) -> None:
+        """Print a debug message."""
+        if self.debug:
+            s = '\n'.join(str(arg) for arg in args)
+            if isinstance(self.debug, str):
+                print(self.debug % s)
+            elif hasattr(self.debug, 'write'):
+                # noinspection PyCallingNonCallable
+                self.debug.write(s + '\n')
+            elif callable(self.debug):
+                self.debug(s)
+            else:
+                print(s)
+
+    def _escape_qualified_name(self, s: str) -> str:
+        """Escape a qualified name.
+
+        Escapes the name for use as an SQL identifier, unless the
+        name contains a dot, in which case the name is ambiguous
+        (could be a qualified name or just a name with a dot in it)
+        and must be quoted manually by the caller.
+        """
+        if '.' not in s:
+            s = self.escape_identifier(s)
+        return s
+
+    @staticmethod
+    def _make_bool(d: Any) -> bool | str:
+        """Get boolean value corresponding to d."""
+        return bool(d) if get_bool() else ('t' if d else 'f')
+
+    @staticmethod
+    def _list_params(params: Sequence) -> str:
+        """Create a human-readable parameter list."""
+        return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1))
+
+    @property
+    def _valid_db(self) -> Connection:
+        """Get underlying connection and make sure it is not closed."""
+        db = self.db
+        if not db:
+            raise int_error('Connection already closed')
+        return db
+
+    # Public methods
+
+    # escape_string and escape_bytea exist as methods,
+    # so we define unescape_bytea as a method as well
+    unescape_bytea = staticmethod(unescape_bytea)
+
+    @staticmethod
+    def decode_json(s: str) -> Any:
+        """Decode a JSON string coming from the database."""
+        return (get_jsondecode() or jsondecode)(s)
+
+    @staticmethod
+    def encode_json(d: Any) -> str:
+        """Encode a JSON string for use within SQL."""
+        return jsonencode(d)
+
+    def close(self) -> None:
+        """Close the database connection."""
+        # Wraps shared library function so we can track state.
+        db = self._valid_db
+        with suppress(TypeError):  # when already closed
+            db.set_cast_hook(None)
+        if self._closeable:
+            db.close()
+        self.db = None
+
+    def reset(self) -> None:
+        """Reset connection with current parameters.
+ + All derived queries and large objects derived from this connection + will not be usable after this call. + """ + self._valid_db.reset() + + def reopen(self) -> None: + """Reopen connection to the database. + + Used in case we need another connection to the same database. + Note that we can still reopen a database that we have closed. + """ + # There is no such shared library function. + if self._closeable: + args, kw = self._db_args + db = connect(*args, **kw) + if self.db: + self.db.set_cast_hook(None) + self.db.close() + db.set_cast_hook(self.dbtypes.typecast) + self.db = db + else: + self.db = self._db_args + + def begin(self, mode: str | None = None) -> Query: + """Begin a transaction.""" + qstr = 'BEGIN' + if mode: + qstr += ' ' + mode + return self.query(qstr) + + start = begin + + def commit(self) -> Query: + """Commit the current transaction.""" + return self.query('COMMIT') + + end = commit + + def rollback(self, name: str | None = None) -> Query: + """Roll back the current transaction.""" + qstr = 'ROLLBACK' + if name: + qstr += ' TO ' + name + return self.query(qstr) + + abort = rollback + + def savepoint(self, name: str) -> Query: + """Define a new savepoint within the current transaction.""" + return self.query('SAVEPOINT ' + name) + + def release(self, name: str) -> Query: + """Destroy a previously defined savepoint.""" + return self.query('RELEASE ' + name) + + def get_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any] + ) -> str | list[str] | dict[str, str]: + """Get the value of a run-time parameter. + + If the parameter is a string, the return value will also be a string + that is the current setting of the run-time parameter with that name. + + You can get several parameters at once by passing a list, set or dict. + When passing a list of parameter names, the return value will be a + corresponding list of parameter settings. When passing a set of + parameter names, a new dict will be returned, mapping these parameter + names to their settings. Finally, if you pass a dict as parameter, + its values will be set to the current parameter settings corresponding + to its keys. + + By passing the special name 'all' as the parameter, you can get a dict + of all existing configuration parameters. + """ + values: Any + if isinstance(parameter, str): + parameter = [parameter] + values = None + elif isinstance(parameter, (list, tuple)): + values = [] + elif isinstance(parameter, (set, frozenset)): + values = {} + elif isinstance(parameter, dict): + values = parameter + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + query = self._valid_db.query + params: Any = {} if isinstance(values, dict) else [] + for param_key in parameter: + param = param_key.strip().lower() if isinstance( + param_key, (bytes, str)) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + cmd = 'SHOW ALL' + values = query(cmd).getresult() + values = {value[0]: value[1] for value in values} + break + if isinstance(params, dict): + params[param] = param_key + else: + params.append(param) + else: + for param in params: + cmd = f'SHOW {param}' + value = query(cmd).singlescalar() + if values is None: + values = value + elif isinstance(values, list): + values.append(value) + else: + values[params[param]] = value + return values + + def set_parameter(self, + parameter: str | list[str] | tuple[str, ...] 
| + set[str] | frozenset[str] | dict[str, Any], + value: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str]| None = None, + local: bool = False) -> None: + """Set the value of a run-time parameter. + + If the parameter and the value are strings, the run-time parameter + will be set to that value. If no value or None is passed as a value, + then the run-time parameter will be restored to its default value. + + You can set several parameters at once by passing a list of parameter + names, together with a single value that all parameters should be + set to or with a corresponding list of values. You can also pass + the parameters as a set if you only provide a single value. + Finally, you can pass a dict with parameter names as keys. In this + case, you should not pass a value, since the values for the parameters + will be taken from the dict. + + By passing the special name 'all' as the parameter, you can reset + all existing settable run-time parameters to their default values. + + If you set local to True, then the command takes effect for only the + current transaction. After commit() or rollback(), the session-level + setting takes effect again. Setting local to True will appear to + have no effect if it is executed outside a transaction, since the + transaction will end immediately. + """ + if isinstance(parameter, str): + parameter = {parameter: value} + elif isinstance(parameter, (list, tuple)): + if isinstance(value, (list, tuple)): + parameter = dict(zip(parameter, value)) + else: + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, (set, frozenset)): + if isinstance(value, (list, tuple, set, frozenset)): + value = set(value) + if len(value) == 1: + value = next(iter(value)) + if not (value is None or isinstance(value, str)): + raise ValueError( + 'A single value must be specified' + ' when parameter is a set') + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, dict): + if value is not None: + raise ValueError( + 'A value must not be specified' + ' when parameter is a dictionary') + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + params: dict[str, str | None] = {} + for param, param_value in parameter.items(): + param = param.strip().lower() if isinstance(param, str) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + if param_value is not None: + raise ValueError( + 'A value must not be specified' + " when parameter is 'all'") + params = {'all': None} + break + params[param] = param_value + local_clause = ' LOCAL' if local else '' + for param, param_value in params.items(): + cmd = (f'RESET{local_clause} {param}' + if param_value is None else + f'SET{local_clause} {param} TO {param_value}') + self._do_debug(cmd) + self._valid_db.query(cmd) + + def query(self, command: str, *args: Any) -> Query: + """Execute a SQL command string. + + This method simply sends a SQL query to the database. If the query is + an insert statement that inserted exactly one row into a table that + has OIDs, the return value is the OID of the newly inserted row. + If the query is an update or delete statement, or an insert statement + that did not insert exactly one row in a table with OIDs, then the + number of rows affected is returned as a string. If it is a statement + that returns rows as a result (usually a select statement, but maybe + also an "insert/update ... 
returning" statement), this method returns + a Query object that can be accessed via getresult() or dictresult() + or simply printed. Otherwise, it returns `None`. + + The query can contain numbered parameters of the form $1 in place + of any data constant. Arguments given after the query string will + be substituted for the corresponding numbered parameter. Parameter + values can also be given as a single list or tuple argument. + """ + # Wraps shared library function for debugging. + db = self._valid_db + if args: + self._do_debug(command, args) + return db.query(command, args) + self._do_debug(command) + return db.query(command) + + def query_formatted(self, command: str, + parameters: tuple | list | dict | None = None, + types: tuple | list | dict | None = None, + inline: bool =False) -> Query: + """Execute a formatted SQL command string. + + Similar to query, but using Python format placeholders of the form + %s or %(names)s instead of PostgreSQL placeholders of the form $1. + The parameters must be passed as a tuple, list or dict. You can + also pass a corresponding tuple, list or dict of database types in + order to format the parameters properly in case there is ambiguity. + + If you set inline to True, the parameters will be sent to the database + embedded in the SQL command, otherwise they will be sent separately. + """ + return self.query(*self.adapter.format_query( + command, parameters, types, inline)) + + def query_prepared(self, name: str, *args: Any) -> Query: + """Execute a prepared SQL statement. + + This works like the query() method, except that instead of passing + the SQL command, you pass the name of a prepared statement. If you + pass an empty name, the unnamed statement will be executed. + """ + if name is None: + name = '' + db = self._valid_db + if args: + self._do_debug('EXECUTE', name, args) + return db.query_prepared(name, args) + self._do_debug('EXECUTE', name) + return db.query_prepared(name) + + def prepare(self, name: str, command: str) -> None: + """Create a prepared SQL statement. + + This creates a prepared statement for the given command with the + given name for later execution with the query_prepared() method. + + The name can be empty to create an unnamed statement, in which case + any pre-existing unnamed statement is automatically replaced; + otherwise it is an error if the statement name is already + defined in the current database session. We recommend always using + named queries, since unnamed queries have a limited lifetime and + can be automatically replaced or destroyed by various operations. + """ + if name is None: + name = '' + self._do_debug('prepare', name, command) + self._valid_db.prepare(name, command) + + def describe_prepared(self, name: str | None = None) -> Query: + """Describe a prepared SQL statement. + + This method returns a Query object describing the result columns of + the prepared statement with the given name. If you omit the name, + the unnamed statement will be described if you created one before. + """ + if name is None: + name = '' + return self._valid_db.describe_prepared(name) + + def delete_prepared(self, name: str | None = None) -> Query: + """Delete a prepared SQL statement. + + This deallocates a previously prepared SQL statement with the given + name, or deallocates all prepared statements if you do not specify a + name. Note that prepared statements are also deallocated automatically + when the current session ends. 
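A short sketch of the complete prepared-statement round trip; the statement and table names are made up:

    from pg import DB

    db = DB(dbname='testdb')  # hypothetical connection
    db.prepare('get_city', 'SELECT name FROM city WHERE id = $1')
    db.describe_prepared('get_city').listfields()  # ('name',)
    name = db.query_prepared('get_city', 1).singlescalar()
    db.delete_prepared('get_city')  # or delete_prepared() to drop all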
+ """ + if not name: + name = 'ALL' + cmd = f"DEALLOCATE {name}" + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def pkey(self, table: str, composite: bool = False, flush: bool = False + ) -> str | tuple[str, ...]: + """Get the primary key of a table. + + Single primary keys are returned as strings unless you + set the composite flag. Composite primary keys are always + represented as tuples. Note that this raises a KeyError + if the table does not have a primary key. + + If flush is set then the internal cache for primary keys will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + pkeys = self._pkeys + if flush: + pkeys.clear() + self._do_debug('The pkey cache has been flushed') + try: # cache lookup + pkey = pkeys[table] + except KeyError as e: # cache miss, check the database + cmd = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + quote_if_unqualified('$1', table)) + res = self._valid_db.query(cmd, (table,)).getresult() + if not res: + raise KeyError(f'Table {table} has no primary key') from e + # we want to use the order defined in the primary key index here, + # not the order as defined by the columns in the table + if len(res) > 1: + indkey = res[0][2] + pkey = tuple(row[0] for row in sorted( + res, key=lambda row: indkey.index(row[1]))) + else: + pkey = res[0][0] + pkeys[table] = pkey # cache it + if composite and not isinstance(pkey, tuple): + pkey = (pkey,) + return pkey + + def pkeys(self, table: str) -> tuple[str, ...]: + """Get the primary key of a table as a tuple. + + Same as pkey() with 'composite' set to True. + """ + return self.pkey(table, True) # type: ignore + + def get_databases(self) -> list[str]: + """Get list of databases in the system.""" + return [r[0] for r in self._valid_db.query( + 'SELECT datname FROM pg_catalog.pg_database').getresult()] + + def get_relations(self, kinds: str | Sequence[str] | None = None, + system: bool = False) -> list[str]: + """Get list of relations in connected database of specified kinds. + + If kinds is None or empty, all kinds of relations are returned. + Otherwise, kinds can be a string or sequence of type letters + specifying which kind of relations you want to list. + + Set the system flag if you want to get the system relations as well. + """ + where_parts = [] + if kinds: + where_parts.append( + "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) + if not system: + where_parts.append("s.nspname NOT SIMILAR" + " TO 'pg/_%|information/_schema' ESCAPE '/'") + where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' + cmd = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") + return [r[0] for r in self._valid_db.query(cmd).getresult()] + + def get_tables(self, system: bool = False) -> list[str]: + """Return list of tables in connected database. + + Set the system flag if you want to get the system tables as well. 
+ """ + return self.get_relations('r', system) + + def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False + ) -> AttrDict: + """Given the name of a table, dig out the set of attribute names. + + Returns a read-only dictionary of attribute names (the names are + the keys, the values are the names of the attributes' types) + with the column names in the proper order if you iterate over it. + + If flush is set, then the internal cache for attribute names will + be flushed. This may be necessary after the database schema or + the search path has been changed. + + By default, only a limited number of simple types will be returned. + You can get the registered types after calling use_regtypes(True). + """ + attnames = self._attnames + if flush: + attnames.clear() + self._do_debug('The attnames cache has been flushed') + try: # cache lookup + names = attnames[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + if with_oid: + cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + types = self.dbtypes + names = AttrDict((name[0], types.add(*name[1:])) for name in res) + attnames[table] = names # cache it + return names + + def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: + """Given the name of a table, dig out the set of generated columns. + + Returns a set of column names that are generated and unalterable. + + If flush is set, then the internal cache for generated columns will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + generated = self._generated + if flush: + generated.clear() + self._do_debug('The generated cache has been flushed') + try: # cache lookup + names = generated[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = f"{cmd} AND {self._query_generated}" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + names = frozenset(name[0] for name in res) + generated[table] = names # cache it + return names + + def use_regtypes(self, regtypes: bool | None = None) -> bool: + """Use registered type names instead of simplified type names.""" + if regtypes is None: + return self.dbtypes._regtypes + regtypes = bool(regtypes) + if regtypes != self.dbtypes._regtypes: + self.dbtypes._regtypes = regtypes + self._attnames.clear() + self.dbtypes.clear() + return regtypes + + def has_table_privilege(self, table: str, privilege: str = 'select', + flush: bool = False) -> bool: + """Check whether current user has specified table privilege. + + If flush is set, then the internal cache for table privileges will + be flushed. This may be necessary after privileges have been changed. 
+ """ + privileges = self._privileges + if flush: + privileges.clear() + self._do_debug('The privileges cache has been flushed') + privilege = privilege.lower() + try: # ask cache + ret = privileges[table, privilege] + except KeyError: # cache miss, ask the database + cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + quote_if_unqualified('$1', table)) + query = self._valid_db.query(cmd, (table, privilege)) + ret = query.singlescalar() == self._make_bool(True) + privileges[table, privilege] = ret # cache it + return ret + + def get(self, table: str, row: Any, + keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: + """Get a row from a database table or view. + + This method is the basic mechanism to get a single row. It assumes + that the keyname specifies a unique row. It must be the name of a + single column or a tuple of column names. If the keyname is not + specified, then the primary key for the table is used. + + If row is a dictionary, then the value for the key is taken from it. + Otherwise, the row must be a single value or a tuple of values + corresponding to the passed keyname or primary key. The fetched row + from the table will be returned as a new dictionary or used to replace + the existing values when row was passed as a dictionary. + + The OID is also put into the dictionary if the table has one, but + in order to allow the caller to work with multiple tables, it is + munged as "oid(table)" using the actual name of the table. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if keyname and isinstance(keyname, str): + keyname = (keyname,) + if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if not keyname: + try: # if keyname is not specified, try using the primary key + keyname = self.pkeys(table) + except KeyError as e: # the table has no primary key + # try using the oid instead + if qoid and isinstance(row, dict) and 'oid' in row: + keyname = ('oid',) + else: + raise prg_error( + f'Table {table} has no primary key') from e + else: # the table has a primary key + # check whether all key columns have values + if isinstance(row, dict) and not set(keyname).issubset(row): + # try using the oid instead + if qoid and 'oid' in row: + keyname = ('oid',) + else: + raise KeyError( + 'Missing value in row for specified keyname') + if not isinstance(row, dict): + if not isinstance(row, (tuple, list)): + row = [row] + if len(keyname) != len(row): + raise KeyError( + 'Differing number of items in keyname and row') + row = dict(zip(keyname, row)) + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + what = 'oid, *' if qoid else '*' + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keyname) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if not res: + # make where clause in error message better readable + where = where.replace('OPERATOR(pg_catalog.=)', '=') + raise db_error( + f'No such record in {table}\nwhere {where}\nwith ' + + self._list_params(params)) + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] 
= value + return row + + def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: + """Insert a row into a database table. + + This method inserts a row into a table. The name of the table must + be passed as the first parameter. The other parameters are used for + providing the data of the row that shall be inserted into the table. + If a dictionary is supplied as the second parameter, it starts with + that. Otherwise, it uses a blank dictionary. + Either way the dictionary is updated from the keywords. + + The dictionary is then reloaded with the values actually inserted in + order to pick up values modified by rules, triggers, etc. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + row.update(kw) + if 'oid' in row: + del row['oid'] # do not insert oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + if not name_list: + raise prg_error('No column found that can be inserted') + names, values = ', '.join(name_list), ', '.join(value_list) + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # this should always be true + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any + ) -> dict[str, Any]: + """Update an existing row in a database table. + + Similar to insert, but updates an existing row. The update is based + on the primary key of the table or the OID value as munged by get() + or passed as keyword. The OID will take precedence if provided, so + that it is possible to update the primary key itself. + + The dictionary is then modified to reflect any changes caused by the + update due to triggers, rules, default values, etc. + """ + if table.endswith('*'): + table = table[:-1].rstrip() # need parent table name + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] 
= ('oid',) + keyset = set(keynames) + else: # try using the primary key + try: + keynames = self.pkeys(table) + except KeyError as e: # the table has no primary key + raise prg_error(f'Table {table} has no primary key') from e + keyset = set(keynames) + # check whether all key columns have values + if not keyset.issubset(row): + raise KeyError('Missing value for primary key in row') + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keynames) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + values_list = [] + for n in attnames: + if n in row and n not in keyset and n not in generated: + values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}') + if not values_list: + return row + values = ', '.join(values_list) + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'UPDATE {t} SET {values}' # noqa: S608 + f' WHERE {where} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # may be empty when row does not exist + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: + """Insert a row into a database table with conflict resolution. + + This method inserts a row into a table, but instead of raising a + ProgrammingError exception in case a row with the same primary key + already exists, an update will be executed instead. This will be + performed as a single atomic operation on the database, so race + conditions can be avoided. + + Like the insert method, the first parameter is the name of the + table and the second parameter can be used to pass the values to + be inserted as a dictionary. + + Unlike the insert und update statement, keyword parameters are not + used to modify the dictionary, but to specify which columns shall + be updated in case of a conflict, and in which way: + + A value of False or None means the column shall not be updated, + a value of True means the column shall be updated with the value + that has been proposed for insertion, i.e. has been passed as value + in the dictionary. Columns that are not specified by keywords but + appear as keys in the dictionary are also updated like in the case + keywords had been passed with the value True. + + So if in the case of a conflict you want to update every column + that has been passed in the dictionary row, you would call + upsert(table, row). If you don't want to do anything in case + of a conflict, i.e. leave the existing row as it is, call + upsert(table, row, **dict.fromkeys(row)). + + If you need more fine-grained control of what gets updated, you can + also pass strings in the keyword parameters. These strings will + be used as SQL expressions for the update columns. In these + expressions you can refer to the value that already exists in + the table by prefixing the column name with "included.", and to + the value that has been proposed for insertion by prefixing the + column name with the "excluded." + + The dictionary is modified in any case to reflect the values in + the database after the operation has completed. + + Note: The method uses the PostgreSQL "upsert" feature which is + only available since PostgreSQL 9.5. 
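+
+        A hypothetical example (a table "counter" with primary key "name"
+        and an integer column "n"), incrementing the stored count by the
+        proposed value on conflict::
+
+            db.upsert('counter', {'name': 'visits', 'n': 1},
+                      n='included.n + excluded.n')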
+ """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + if 'oid' in row: + del row['oid'] # do not insert oid + if 'oid' in kw: + del kw['oid'] # do not update oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + names, values = ', '.join(name_list), ', '.join(value_list) + try: + keynames = self.pkeys(table) + except KeyError as e: + raise prg_error(f'Table {table} has no primary key') from e + target = ', '.join(col(k) for k in keynames) + update = [] + keyset = set(keynames) + keyset.add('oid') + for n in attnames: + if n not in keyset and n not in generated: + value = kw.get(n, n in row) + if value: + if not isinstance(value, str): + value = f'excluded.{col(n)}' + update.append(f'{col(n)} = {value}') + if not values: + return row + do = 'update set ' + ', '.join(update) if update else 'nothing' + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # may be empty with "do nothing" + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + else: + self.get(table, row) + return row + + def clear(self, table: str, row: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Clear all the attributes to values determined by the types. + + Numeric types are set to 0, Booleans are set to false, and everything + else is set to the empty string. If the row argument is present, + it is used as the row dictionary and any entries matching attribute + names are cleared with everything else left unchanged. + """ + # At some point we will need a way to get defaults from a table. + if row is None: + row = {} # empty if argument is not present + attnames = self.get_attnames(table) + for n, t in attnames.items(): + if n == 'oid': + continue + t = t.simple + if t in DbTypes._num_types: + row[n] = 0 + elif t == 'bool': + row[n] = self._make_bool(False) + else: + row[n] = '' + return row + + def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> int: + """Delete an existing row in a database table. + + This method deletes the row from a table. It deletes based on the + primary key of the table or the OID value as munged by get() or + passed as keyword. The OID will take precedence if provided. + + The return value is the number of deleted rows (i.e. 0 if the row + did not exist and 1 if the row was deleted). + + Note that if the row cannot be deleted because e.g. it is still + referenced by another table, this method raises a ProgrammingError. 
+ """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] = ('oid',) + else: # try using the primary key + try: + keynames = self.pkeys(table) + except KeyError as e: # the table has no primary key + raise prg_error(f'Table {table} has no primary key') from e + # check whether all key columns have values + if not set(keynames).issubset(row): + raise KeyError('Missing value for primary key in row') + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keynames) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 + self._do_debug(cmd, params) + res = self._valid_db.query(cmd, params) + return int(res) # type: ignore + + def truncate(self, table: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str], restart: bool = False, + cascade: bool = False, only: bool = False) -> Query: + """Empty a table or set of tables. + + This method quickly removes all rows from the given table or set + of tables. It has the same effect as an unqualified DELETE on each + table, but since it does not actually scan the tables it is faster. + Furthermore, it reclaims disk space immediately, rather than requiring + a subsequent VACUUM operation. This is most useful on large tables. + + If restart is set to True, sequences owned by columns of the truncated + table(s) are automatically restarted. If cascade is set to True, it + also truncates all tables that have foreign-key references to any of + the named tables. If the parameter 'only' is not set to True, all the + descendant tables (if any) will also be truncated. Optionally, a '*' + can be specified after the table name to explicitly indicate that + descendant tables are included. 
+ """ + if isinstance(table, str): + table_only = {table: only} + table = [table] + elif isinstance(table, (list, tuple)): + if isinstance(only, (list, tuple)): + table_only = dict(zip(table, only)) + else: + table_only = dict.fromkeys(table, only) + elif isinstance(table, (set, frozenset)): + table_only = dict.fromkeys(table, only) + else: + raise TypeError('The table must be a string, list or set') + if not (restart is None or isinstance(restart, (bool, int))): + raise TypeError('Invalid type for the restart option') + if not (cascade is None or isinstance(cascade, (bool, int))): + raise TypeError('Invalid type for the cascade option') + tables = [] + for t in table: + u = table_only.get(t) + if not (u is None or isinstance(u, (bool, int))): + raise TypeError('Invalid type for the only option') + if t.endswith('*'): + if u: + raise ValueError( + 'Contradictory table name and only options') + t = t[:-1].rstrip() + t = self._escape_qualified_name(t) + if u: + t = f'ONLY {t}' + tables.append(t) + cmd_parts = ['TRUNCATE', ', '.join(tables)] + if restart: + cmd_parts.append('RESTART IDENTITY') + if cascade: + cmd_parts.append('CASCADE') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def get_as_list( + self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: + """Get a table as a list. + + This gets a convenient representation of the table as a list + of named tuples in Python. You only need to pass the name of + the table (or any other SQL expression returning rows). Note that + by default this will return the full content of the table which + can be huge and overflow your memory. However, you can control + the amount of data returned using the other optional parameters. + + The parameter 'what' can restrict the query to only return a + subset of the table columns. It can be a string, list or a tuple. + + The parameter 'where' can restrict the query to only return a + subset of the table rows. It can be a string, list or a tuple + of SQL expressions that all need to be fulfilled. + + The parameter 'order' specifies the ordering of the rows. It can + also be a string, list or a tuple. If no ordering is specified, + the result will be ordered by the primary key(s) or all columns if + no primary key exists. You can set 'order' to False if you don't + care about the ordering. The parameters 'limit' and 'offset' can be + integers specifying the maximum number of rows returned and a number + of rows skipped over. + + If you set the 'scalar' option to True, then instead of the + named tuples you will get the first items of these tuples. + This is useful if the result has only one column anyway. 
+ """ + if not table: + raise TypeError('The table name is missing') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + try: + order = self.pkeys(table) + except (KeyError, ProgrammingError): + with suppress(KeyError, ProgrammingError): + order = list(self.get_attnames(table)) + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.namedresult() + if res and scalar: + res = [row[0] for row in res] + return res + + def get_as_dict( + self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: + """Get a table as a dictionary. + + This method is similar to get_as_list(), but returns the table + as a Python dict instead of a Python list, which can be even + more convenient. The primary key column(s) of the table will + be used as the keys of the dictionary, while the other column(s) + will be the corresponding values. The keys will be named tuples + if the table has a composite primary key. The rows will be also + named tuples unless the 'scalar' option has been set to True. + With the optional parameter 'keyname' you can specify an alternative + set of columns to be used as the keys of the dictionary. It must + be set as a string, list or a tuple. + + The dictionary will be ordered using the order specified with the + 'order' parameter or the key column(s) if not specified. You can + set 'order' to False if you don't care about the ordering. + """ + if not table: + raise TypeError('The table name is missing') + if not keyname: + try: + keyname = self.pkeys(table) + except (KeyError, ProgrammingError) as e: + raise prg_error(f'Table {table} has no primary key') from e + if isinstance(keyname, str): + keynames: list[str] | tuple[str, ...] 
= (keyname,) + elif isinstance(keyname, (list, tuple)): + keynames = keyname + else: + raise KeyError('The keyname must be a string, list or tuple') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + order = keyname + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.getresult() + if not res: + return {} + keyset = set(keynames) + fields = query.listfields() + if not keyset.issubset(fields): + raise KeyError('Missing keyname in row') + key_index: list[int] = [] + row_index: list[int] = [] + for i, f in enumerate(fields): + (key_index if f in keyset else row_index).append(i) + key_tuple = len(key_index) > 1 + get_key = itemgetter(*key_index) + keys = map(get_key, res) + if scalar: + row_index = row_index[:1] + row_is_tuple = False + else: + row_is_tuple = len(row_index) > 1 + if scalar or row_is_tuple: + get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore + *row_index) + else: + frst_index = row_index[0] + + def get_row(row : tuple) -> tuple: + return row[frst_index], # tuple with one item + + row_is_tuple = True + rows = map(get_row, res) + if key_tuple or row_is_tuple: + if key_tuple: + keys = namediter(_MemoryQuery(keys, keynames)) # type: ignore + if row_is_tuple: + fields = tuple(f for f in fields if f not in keyset) + rows = namediter(_MemoryQuery(rows, fields)) # type: ignore + # noinspection PyArgumentList + return dict(zip(keys, rows)) + + def notification_handler(self, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None + ) -> NotificationHandler: + """Get notification handler that will run the given callback.""" + return NotificationHandler(self, event, callback, + arg_dict, timeout, stop_event) + + +class _MemoryQuery: + """Class that embodies a given query result.""" + + result: Any + fields: tuple[str, ...] 
+ + def __init__(self, result: Any, fields: Sequence[str]) -> None: + """Create query from given result rows and field names.""" + self.result = result + self.fields = tuple(fields) + + def listfields(self) -> tuple[str, ...]: + """Return the stored field names of this query.""" + return self.fields + + def getresult(self) -> Any: + """Return the stored result of this query.""" + return self.result + + def __iter__(self) -> Iterator[Any]: + return iter(self.result) \ No newline at end of file diff --git a/pg/error.py b/pg/error.py new file mode 100644 index 00000000..b3164b42 --- /dev/null +++ b/pg/error.py @@ -0,0 +1,35 @@ +"""Error helpers.""" + +from __future__ import annotations + +from typing import TypeVar + +from .core import DatabaseError, Error, InternalError, ProgrammingError + +__all__ = ['error', 'db_error', 'int_error', 'prg_error'] + +# Error messages + +E = TypeVar('E', bound=Error) + +def error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" + error = cls(msg) + if isinstance(error, DatabaseError): + error.sqlstate = None + return error + + +def db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return error(msg, DatabaseError) + + +def int_error(msg: str) -> InternalError: + """Return InternalError.""" + return error(msg, InternalError) + + +def prg_error(msg: str) -> ProgrammingError: + """Return ProgrammingError.""" + return error(msg, ProgrammingError) \ No newline at end of file diff --git a/pg/helpers.py b/pg/helpers.py new file mode 100644 index 00000000..4426cfbc --- /dev/null +++ b/pg/helpers.py @@ -0,0 +1,98 @@ +"""Helper functions.""" + +from __future__ import annotations + +from collections import namedtuple +from decimal import Decimal +from functools import lru_cache +from json import loads as jsondecode +from typing import Any, Callable, Generator, NamedTuple, Sequence + +from .core import Query, set_decimal, set_jsondecode, set_query_helpers + +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +__all__ = [ + 'quote_if_unqualified', 'oid_key', 'set_row_factory_size', + 'dictiter', 'namediter', 'namednext', 'scalariter' +] + + +# Small helper functions + +def quote_if_unqualified(param: str, name: int | str) -> str: + """Quote parameter representing a qualified name. + + Puts a quote_ident() call around the given parameter unless + the name contains a dot, in which case the name is ambiguous + (could be a qualified name or just a name with a dot in it) + and must be quoted manually by the caller. + """ + if isinstance(name, str) and '.' not in name: + return f'quote_ident({param})' + return param + +def oid_key(table: str) -> str: + """Build oid key from a table name.""" + return f'oid({table})' + + +# Row factory + +# The result rows for database operations are returned as named tuples +# by default. Since creating namedtuple classes is a somewhat expensive +# operation, we cache up to 1024 of these classes by default. + +@lru_cache(maxsize=1024) +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: + """Get a namedtuple factory for row results with the given names.""" + try: + return namedtuple('Row', names, rename=True)._make # type: ignore + except ValueError: # there is still a problem with the field names + names = [f'column_{n}' for n in range(len(names))] + return namedtuple('Row', names)._make # type: ignore + + +def set_row_factory_size(maxsize: int | None) -> None: + """Change the size of the namedtuple factory cache. 
+ + If maxsize is set to None, the cache can grow without bound. + """ + global _row_factory + _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) + + +# Helper functions used by the query object + +def dictiter(q: Query) -> Generator[dict[str, Any], None, None]: + """Get query result as an iterator of dictionaries.""" + fields: tuple[str, ...] = q.listfields() + for r in q: + yield dict(zip(fields, r)) + + +def namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: + """Get query result as an iterator of named tuples.""" + row = _row_factory(q.listfields()) + for r in q: + yield row(r) + + +def namednext(q: Query) -> SomeNamedTuple: + """Get next row from query result as a named tuple.""" + return _row_factory(q.listfields())(next(q)) + + +def scalariter(q: Query) -> Generator[Any, None, None]: + """Get query result as an iterator of scalar values.""" + for r in q: + yield r[0] + + +# Initialization + +def init_core() -> None: + """Initialize the C extension module.""" + set_decimal(Decimal) + set_jsondecode(jsondecode) + set_query_helpers(dictiter, namediter, namednext, scalariter) diff --git a/pg/notify.py b/pg/notify.py new file mode 100644 index 00000000..e273c521 --- /dev/null +++ b/pg/notify.py @@ -0,0 +1,149 @@ +"""Handling of notifications.""" + +from __future__ import annotations + +import select +from typing import TYPE_CHECKING, Callable + +from .core import Query +from .error import db_error + +if TYPE_CHECKING: + from .db import DB + +__all__ = ['NotificationHandler'] + +# The notification handler + +class NotificationHandler: + """A PostgreSQL client-side asynchronous notification handler.""" + + def __init__(self, db: DB, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None): + """Initialize the notification handler. + + You must pass a PyGreSQL database connection, the name of an + event (notification channel) to listen for and a callback function. + + You can also specify a dictionary arg_dict that will be passed as + the single argument to the callback function, and a timeout value + in seconds (a floating point number denotes fractions of seconds). + If it is absent or None, the callers will never time out. If the + timeout is reached, the callback function will be called with a + single argument that is None. If you set the timeout to zero, + the handler will poll notifications synchronously and return. + + You can specify the name of the event that will be used to signal + the handler to stop listening as stop_event. By default, it will + be the event name prefixed with 'stop_'. 
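+
+        A hypothetical example (the channel name is made up)::
+
+            def on_notice(arg_dict):
+                print('got notice:', arg_dict)
+
+            handler = db.notification_handler('my_event', on_notice)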
+ """ + self.db: DB | None = db + self.event = event + self.stop_event = stop_event or f'stop_{event}' + self.listening = False + self.callback = callback + if arg_dict is None: + arg_dict = {} + self.arg_dict = arg_dict + self.timeout = timeout + + def __del__(self) -> None: + """Delete the notification handler.""" + self.unlisten() + + def close(self) -> None: + """Stop listening and close the connection.""" + if self.db: + self.unlisten() + self.db.close() + self.db = None + + def listen(self) -> None: + """Start listening for the event and the stop event.""" + db = self.db + if db and not self.listening: + db.query(f'listen "{self.event}"') + db.query(f'listen "{self.stop_event}"') + self.listening = True + + def unlisten(self) -> None: + """Stop listening for the event and the stop event.""" + db = self.db + if db and self.listening: + db.query(f'unlisten "{self.event}"') + db.query(f'unlisten "{self.stop_event}"') + self.listening = False + + def notify(self, db: DB | None = None, stop: bool = False, + payload: str | None = None) -> Query | None: + """Generate a notification. + + Optionally, you can pass a payload with the notification. + + If you set the stop flag, a stop notification will be sent that + will cause the handler to stop listening. + + Note: If the notification handler is running in another thread, you + must pass a different database connection since PyGreSQL database + connections are not thread-safe. + """ + if not self.listening: + return None + if not db: + db = self.db + if not db: + return None + event = self.stop_event if stop else self.event + cmd = f'notify "{event}"' + if payload: + cmd += f", '{payload}'" + return db.query(cmd) + + def __call__(self) -> None: + """Invoke the notification handler. + + The handler is a loop that listens for notifications on the event + and stop event channels. When either of these notifications are + received, its associated 'pid', 'event' and 'extra' (the payload + passed with the notification) are inserted into its arg_dict + dictionary and the callback is invoked with this dictionary as + a single argument. When the handler receives a stop event, it + stops listening to both events and return. + + In the special case that the timeout of the handler has been set + to zero, the handler will poll all events synchronously and return. + If will keep listening until it receives a stop event. + + Note: If you run this loop in another thread, don't use the same + database connection for database operations in the main thread. 
+ """ + if not self.db: + return + self.listen() + poll = self.timeout == 0 + rlist = [] if poll else [self.db.fileno()] + while self.db and self.listening: + # noinspection PyUnboundLocalVariable + if poll or select.select(rlist, [], [], self.timeout)[0]: + while self.db and self.listening: + notice = self.db.getnotify() + if not notice: # no more messages + break + event, pid, extra = notice + if event not in (self.event, self.stop_event): + self.unlisten() + raise db_error( + f'Listening for "{self.event}"' + f' and "{self.stop_event}",' + f' but notified of "{event}"') + if event == self.stop_event: + self.unlisten() + self.arg_dict.update(pid=pid, event=event, extra=extra) + self.callback(self.arg_dict) + if poll: + break + else: # we timed out + self.unlisten() + self.callback(None) \ No newline at end of file diff --git a/pg/tz.py b/pg/tz.py new file mode 100644 index 00000000..7f22e049 --- /dev/null +++ b/pg/tz.py @@ -0,0 +1,21 @@ +"""Timezone helpers.""" + +from __future__ import annotations + +__all__ = ['timezone_as_offset'] + +# time zones used in Postgres timestamptz output +_timezone_offsets: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} + + +def timezone_as_offset(tz: str) -> str: + """Convert timezone abbreviation to offset.""" + if tz.startswith(('+', '-')): + if len(tz) < 5: + return tz + '00' + return tz.replace(':', '') + return _timezone_offsets.get(tz, '+0000') \ No newline at end of file diff --git a/tests/test_classic_attrdict.py b/tests/test_classic_attrdict.py new file mode 100644 index 00000000..8eef00df --- /dev/null +++ b/tests/test_classic_attrdict.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +"""Test the classic PyGreSQL interface. + +Sub-tests for the DB wrapper object. + +Contributed by Christoph Zwerschke. + +These tests need a database to test against. 
+""" + +import unittest + +import pg.attrs # the module under test + + +class TestAttrDict(unittest.TestCase): + """Test the simple ordered dictionary for attribute names.""" + + cls = pg.attrs.AttrDict + + def test_init(self): + a = self.cls() + self.assertIsInstance(a, dict) + self.assertEqual(a, {}) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + iteritems = iter(items) + a = self.cls(iteritems) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + + def test_iter(self): + a = self.cls() + self.assertEqual(list(a), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a), keys) + + def test_keys(self): + a = self.cls() + self.assertEqual(list(a.keys()), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a.keys()), keys) + + def test_values(self): + a = self.cls() + self.assertEqual(list(a.values()), []) + items = [('id', 'int'), ('name', 'text')] + values = [item[1] for item in items] + a = self.cls(items) + self.assertEqual(list(a.values()), values) + + def test_items(self): + a = self.cls() + self.assertEqual(list(a.items()), []) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertEqual(list(a.items()), items) + + def test_get(self): + a = self.cls([('id', 1)]) + try: + self.assertEqual(a['id'], 1) + except KeyError: + self.fail('AttrDict should be readable') + + def test_set(self): + a = self.cls() + try: + a['id'] = 1 + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_del(self): + a = self.cls([('id', 1)]) + try: + del a['id'] + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_write_methods(self): + a = self.cls([('id', 1)]) + self.assertEqual(a['id'], 1) + for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': + method = getattr(a, method) + self.assertRaises(TypeError, method, a) # type: ignore + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index d6a742bf..eca64afd 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2353,7 +2353,7 @@ def test_get_decimal_point(self): self.assertIsNone(r) def test_set_decimal_point(self): - d = pg.Decimal + d = Decimal point = pg.get_decimal_point() self.assertRaises(TypeError, pg.set_decimal_point) # error if decimal point is not a string @@ -2480,7 +2480,7 @@ def test_get_decimal(self): decimal_class = pg.get_decimal() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal, decimal_class) - self.assertIs(decimal_class, pg.Decimal) # the default setting + self.assertIs(decimal_class, Decimal) # the default setting pg.set_decimal(int) try: r = pg.get_decimal() @@ -2499,7 +2499,6 @@ def test_set_decimal(self): r = query("select 3425::numeric") except pg.DatabaseError: self.skipTest('database does not support numeric') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, decimal_class) self.assertEqual(r, decimal_class('3425')) @@ -2557,7 +2556,6 @@ def test_set_bool(self): r = query("select true::bool") except pg.ProgrammingError: self.skipTest('database does not support bool') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, bool) self.assertEqual(r, True) @@ -2620,7 +2618,6 @@ def test_set_bytea_escaped(self): r = query("select 
'data'::bytea") except pg.ProgrammingError: self.skipTest('database does not support bytea') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') @@ -2653,7 +2650,8 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - info = pg._row_factory.cache_info() + from pg.helpers import _row_factory + info = _row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) self.assertEqual( diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 74d6df8e..8ebb8214 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -49,88 +49,6 @@ def DB(): # noqa: N802 return db -class TestAttrDict(unittest.TestCase): - """Test the simple ordered dictionary for attribute names.""" - - cls = pg.AttrDict - - def test_init(self): - a = self.cls() - self.assertIsInstance(a, dict) - self.assertEqual(a, {}) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertIsInstance(a, dict) - self.assertEqual(a, dict(items)) - iteritems = iter(items) - a = self.cls(iteritems) - self.assertIsInstance(a, dict) - self.assertEqual(a, dict(items)) - - def test_iter(self): - a = self.cls() - self.assertEqual(list(a), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a), keys) - - def test_keys(self): - a = self.cls() - self.assertEqual(list(a.keys()), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a.keys()), keys) - - def test_values(self): - a = self.cls() - self.assertEqual(list(a.values()), []) - items = [('id', 'int'), ('name', 'text')] - values = [item[1] for item in items] - a = self.cls(items) - self.assertEqual(list(a.values()), values) - - def test_items(self): - a = self.cls() - self.assertEqual(list(a.items()), []) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertEqual(list(a.items()), items) - - def test_get(self): - a = self.cls([('id', 1)]) - try: - self.assertEqual(a['id'], 1) - except KeyError: - self.fail('AttrDict should be readable') - - def test_set(self): - a = self.cls() - try: - a['id'] = 1 - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def test_del(self): - a = self.cls([('id', 1)]) - try: - del a['id'] - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def test_write_methods(self): - a = self.cls([('id', 1)]) - self.assertEqual(a['id'], 1) - for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': - method = getattr(a, method) - self.assertRaises(TypeError, method, a) # type: ignore - - class TestDBClassInit(unittest.TestCase): """Test proper handling of errors when creating DB instances.""" @@ -491,8 +409,8 @@ def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') def test_module_name(self): - self.assertEqual(self.db.__module__, 'pg') - self.assertEqual(self.db.__class__.__module__, 'pg') + self.assertEqual(self.db.__module__, 'pg.db') + self.assertEqual(self.db.__class__.__module__, 'pg.db') def test_escape_literal(self): f = self.db.escape_literal @@ -1437,21 +1355,21 @@ def test_get_attnames_is_ordered(self): self.assertEqual(r, 'n alpha v gamma tau beta') def test_get_attnames_is_attr_dict(self): - AttrDict = pg.AttrDict # noqa: N806 + from pg.attrs import AttrDict get_attnames = 
self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + self.assertEqual(r, AttrDict({ + 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', + 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', + 'm': 'money', 'v4': 'character varying', + 'c4': 'character', 't': 'text'})) else: - self.assertEqual(r, AttrDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) + self.assertEqual(r, AttrDict({ + 'i2': 'int', 'i4': 'int', 'i8': 'int', + 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', + 'v4': 'text', 'c4': 'text', 't': 'text'})) r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' @@ -1461,14 +1379,14 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) + self.assertEqual(r, AttrDict({ + 'n': 'integer', 'alpha': 'smallint', + 'v': 'character varying', 'gamma': 'character', + 'tau': 'text', 'beta': 'boolean'})) else: - self.assertEqual(r, AttrDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) + self.assertEqual(r, AttrDict({ + 'n': 'int', 'alpha': 'int', 'v': 'text', + 'gamma': 'text', 'tau': 'text', 'beta': 'bool'})) r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') @@ -4204,7 +4122,8 @@ def test_get_set_type_cast(self): self.assertNotIn('bool', dbtypes) self.assertIs(get_typecast('int4'), int) self.assertIs(get_typecast('float4'), float) - self.assertIs(get_typecast('bool'), pg.cast_bool) + from pg.cast import cast_bool + self.assertIs(get_typecast('bool'), cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 @@ -4416,14 +4335,19 @@ def test_adapt_query_typed_list(self): values = [(3, 7.5, 'hello', True, [123], ['abc'])] t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict({ + 'i': t('int'), 'f': t('float'), + 't': t('text'), 'b': t('bool'), + 'i3': t('int[]'), 't3': t('text[]')}) types = [typ] sql, params = format_query('select %s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_typed_list_with_types_as_string(self): format_query = self.adapter.format_query @@ -4527,14 +4451,19 @@ def test_adapt_query_typed_dict(self): values = dict(record=(3, 7.5, 'hello', True, [123], ['abc'])) t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: 
pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict({ + 'i': t('int'), 'f': t('float'), + 't': t('text'), 'b': t('bool'), + 'i3': t('int[]'), 't3': t('text[]')}) types = dict(record=typ) sql, params = format_query('select %(record)s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query @@ -4560,6 +4489,10 @@ def test_adapt_query_untyped_list(self): sql, params = format_query('select %s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_untyped_list_with_json(self): format_query = self.adapter.format_query @@ -4601,6 +4534,10 @@ def test_adapt_query_untyped_dict(self): sql, params = format_query('select %(record)s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_inline_list(self): format_query = self.adapter.format_query @@ -4629,6 +4566,11 @@ def test_adapt_query_inline_list(self): self.assertEqual( sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") + self.assertEqual(params, []) def test_adapt_query_inline_list_with_json(self): format_query = self.adapter.format_query @@ -4676,6 +4618,11 @@ def test_adapt_query_inline_dict(self): self.assertEqual( sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") + self.assertEqual(params, []) def test_adapt_query_with_pg_repr(self): format_query = self.adapter.format_query diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 19214c5d..01ed752e 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -15,6 +15,7 @@ import re import unittest from datetime import timedelta +from decimal import Decimal from typing import Any, Sequence import pg # the module under test @@ -854,15 +855,15 @@ class TestCastInterval(unittest.TestCase): 'P-10M-3DT3H55M5.999993S'))] def test_cast_interval(self): + from pg.cast import cast_interval for result, values in self.intervals: - f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result days += 365 * years + 30 * mons interval = timedelta( days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) for value in values: - 
self.assertEqual(f(value), interval) + self.assertEqual(cast_interval(value), interval) class TestEscapeFunctions(unittest.TestCase): @@ -970,10 +971,10 @@ def test_set_decimal_point(self): def test_get_decimal(self): r = pg.get_decimal() - self.assertIs(r, pg.Decimal) + self.assertIs(r, Decimal) def test_set_decimal(self): - decimal_class = pg.Decimal + decimal_class = Decimal try: pg.set_decimal(int) r = pg.get_decimal() From 493c6ee81670e0fb8bcd1f397ab87e66e09c6924 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 15:02:30 +0200 Subject: [PATCH 157/194] Nicer initialization of AttrDict in tests --- tests/test_classic_dbwrapper.py | 48 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8ebb8214..2ddde601 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1360,16 +1360,16 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict({ - 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', - 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', - 'm': 'money', 'v4': 'character varying', - 'c4': 'character', 't': 'text'})) + self.assertEqual(r, AttrDict( + i2='smallint', i4='integer', i8='bigint', + d='numeric', f4='real', f8='double precision', + m='money', v4='character varying', + c4='character', t='text')) else: - self.assertEqual(r, AttrDict({ - 'i2': 'int', 'i4': 'int', 'i8': 'int', - 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', - 'v4': 'text', 'c4': 'text', 't': 'text'})) + self.assertEqual(r, AttrDict( + i2='int', i4='int', i8='int', + d='num', f4='float', f8='float', m='money', + v4='text', c4='text', t='text')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' @@ -1379,14 +1379,14 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict({ - 'n': 'integer', 'alpha': 'smallint', - 'v': 'character varying', 'gamma': 'character', - 'tau': 'text', 'beta': 'boolean'})) + self.assertEqual(r, AttrDict( + n='integer', alpha='smallint', + v='character varying', gamma='character', + tau='text', beta='boolean')) else: - self.assertEqual(r, AttrDict({ - 'n': 'int', 'alpha': 'int', 'v': 'text', - 'gamma': 'text', 'tau': 'text', 'beta': 'bool'})) + self.assertEqual(r, AttrDict( + n='int', alpha='int', v='text', + gamma='text', tau='text', beta='bool')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') @@ -4336,10 +4336,10 @@ def test_adapt_query_typed_list(self): t = self.adapter.simple_type typ = t('record') from pg.attrs import AttrDict - typ._get_attnames = lambda _self: AttrDict({ - 'i': t('int'), 'f': t('float'), - 't': t('text'), 'b': t('bool'), - 'i3': t('int[]'), 't3': t('text[]')}) + typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = [typ] sql, params = format_query('select %s', values, types) self.assertEqual(sql, 'select $1') @@ -4452,10 +4452,10 @@ def test_adapt_query_typed_dict(self): t = self.adapter.simple_type typ = t('record') from pg.attrs import AttrDict - typ._get_attnames = lambda _self: AttrDict({ - 'i': t('int'), 'f': t('float'), - 't': t('text'), 'b': t('bool'), - 'i3': t('int[]'), 't3': t('text[]')}) + 
typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = dict(record=typ) sql, params = format_query('select %(record)s', values, types) self.assertEqual(sql, 'select $1') From 3ff98e3d29d06a1bf5b9685ceaab2be6c0b05f4d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 22:46:58 +0200 Subject: [PATCH 158/194] Split pgdb package into submodules --- docs/contents/changelog.rst | 7 + pg/__init__.py | 11 +- pg/adapt.py | 2 +- pg/error.py | 27 +- pg/helpers.py | 60 +- pgdb/__init__.py | 1784 +----------------------------- pgdb/adapt.py | 237 ++++ pgdb/cast.py | 581 ++++++++++ pgdb/connect.py | 74 ++ pgdb/connection.py | 156 +++ pgdb/constants.py | 14 + pgdb/cursor.py | 645 +++++++++++ pgdb/typecode.py | 34 + tests/dbapi20.py | 9 +- tests/test_classic.py | 6 +- tests/test_classic_connection.py | 14 +- tests/test_classic_dbwrapper.py | 5 +- tests/test_dbapi20.py | 72 +- 18 files changed, 1917 insertions(+), 1821 deletions(-) create mode 100644 pgdb/adapt.py create mode 100644 pgdb/cast.py create mode 100644 pgdb/connect.py create mode 100644 pgdb/connection.py create mode 100644 pgdb/constants.py create mode 100644 pgdb/cursor.py create mode 100644 pgdb/typecode.py diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 6afc68dd..077893a2 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,10 +5,17 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Converted the standalone modules `pg` and `pgdb` to packages with + several submodules each. The C extension module is now part of the + `pg` package and wrapped into the pure Python module `pg.core`. +- Added type hints and included a stub file for the C extension module. - Added method `pkeys()` to the `pg.DB` object. - Removed deprecated function `pg.pgnotify()`. - Removed deprecated method `ntuples()` of the `pg.Query` object. - Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. +- `pg` and `pgdb` now use a shared row factory cache. +- The function `set_row_factory_size()` has been removed. The row cache is now + available as a `RowCache` class with methods `change_size()` and `clear()`. - Modernized code and tools for development, testing, linting and building. 
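To illustrate the row cache change noted above, a minimal sketch of migrating from the removed set_row_factory_size() to the new RowCache class (both names as given in this changelog entry):

    from pg import RowCache

    # before this patch: pg.set_row_factory_size(2048)
    RowCache.change_size(2048)  # cache up to 2048 generated row classes
    RowCache.change_size(None)  # or let the cache grow without bound
    RowCache.clear()            # drop all cached row classes
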
Version 5.2.5 (2023-08-28) diff --git a/pg/__init__.py b/pg/__init__.py index e0e1b214..37447c9e 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -95,11 +95,9 @@ version, ) from .db import DB -from .helpers import init_core, set_row_factory_size +from .helpers import RowCache, init_core from .notify import NotificationHandler -__version__ = version - __all__ = [ 'DB', 'Adapter', 'NotificationHandler', 'Typecasts', @@ -110,7 +108,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', + 'Connection', 'Query', 'RowCache', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', @@ -127,9 +125,10 @@ 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', - 'set_row_factory_size', 'set_typecast', + 'set_jsondecode', 'set_query_helpers', 'set_typecast', 'version', '__version__', ] +__version__ = version + init_core() diff --git a/pg/adapt.py b/pg/adapt.py index fd4705ae..9cbecaaf 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -1,4 +1,4 @@ -"""Adaption of parameters.""" +"""Adaptation of parameters.""" from __future__ import annotations diff --git a/pg/error.py b/pg/error.py index b3164b42..484a1252 100644 --- a/pg/error.py +++ b/pg/error.py @@ -4,9 +4,18 @@ from typing import TypeVar -from .core import DatabaseError, Error, InternalError, ProgrammingError - -__all__ = ['error', 'db_error', 'int_error', 'prg_error'] +from .core import ( + DatabaseError, + Error, + InterfaceError, + InternalError, + OperationalError, + ProgrammingError, +) + +__all__ = [ + 'error', 'db_error', 'if_error', 'int_error', 'op_error', 'prg_error' +] # Error messages @@ -32,4 +41,14 @@ def int_error(msg: str) -> InternalError: def prg_error(msg: str) -> ProgrammingError: """Return ProgrammingError.""" - return error(msg, ProgrammingError) \ No newline at end of file + return error(msg, ProgrammingError) + + +def if_error(msg: str) -> InterfaceError: + """Return InterfaceError.""" + return error(msg, InterfaceError) + + +def op_error(msg: str) -> OperationalError: + """Return OperationalError.""" + return error(msg, OperationalError) diff --git a/pg/helpers.py b/pg/helpers.py index 4426cfbc..53689f6a 100644 --- a/pg/helpers.py +++ b/pg/helpers.py @@ -13,7 +13,7 @@ SomeNamedTuple = Any # alias for accessing arbitrary named tuples __all__ = [ - 'quote_if_unqualified', 'oid_key', 'set_row_factory_size', + 'quote_if_unqualified', 'oid_key', 'QuoteDict', 'RowCache', 'dictiter', 'namediter', 'namednext', 'scalariter' ] @@ -36,30 +36,50 @@ def oid_key(table: str) -> str: """Build oid key from a table name.""" return f'oid({table})' +class QuoteDict(dict): + """Dictionary with auto quoting of its items. -# Row factory + The quote attribute must be set to the desired quote function. + """ -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. 
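As a sketch of the caching this comment describes: lru_cache memoizes one namedtuple class per distinct tuple of column names, so queries returning the same columns reuse the generated class (module path as in this hunk):

    from pg.helpers import RowCache

    make_row = RowCache.row_factory(('id', 'name'))
    row = make_row([1, 'Alice'])
    assert row.id == 1 and row.name == 'Alice'
    # the same column names hit the cache and yield the same factory
    assert RowCache.row_factory(('id', 'name')) is make_row
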
+ quote: Callable[[str], str] -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore + def __getitem__(self, key: str) -> str: + """Get a quoted value.""" + return self.quote(super().__getitem__(key)) -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. +class RowCache: + """Global cache for the named tuples used for table rows. - If maxsize is set to None, the cache can grow without bound. + The result rows for database operations are returned as named tuples + by default. Since creating namedtuple classes is a somewhat expensive + operation, we cache up to 1024 of these classes by default. """ - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) + + @staticmethod + @lru_cache(maxsize=1024) + def row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: + """Get a namedtuple factory for row results with the given names.""" + try: + return namedtuple('Row', names, rename=True)._make # type: ignore + except ValueError: # there is still a problem with the field names + names = [f'column_{n}' for n in range(len(names))] + return namedtuple('Row', names)._make # type: ignore + + @classmethod + def clear(cls) -> None: + """Clear the namedtuple factory cache.""" + cls.row_factory.cache_clear() + + @classmethod + def change_size(cls, maxsize: int | None) -> None: + """Change the size of the namedtuple factory cache. + + If maxsize is set to None, the cache can grow without bound. 
+ """ + row_factory = cls.row_factory.__wrapped__ + cls.row_factory = lru_cache(maxsize)(row_factory) # type: ignore # Helper functions used by the query object @@ -73,14 +93,14 @@ def dictiter(q: Query) -> Generator[dict[str, Any], None, None]: def namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: """Get query result as an iterator of named tuples.""" - row = _row_factory(q.listfields()) + row = RowCache.row_factory(q.listfields()) for r in q: yield row(r) def namednext(q: Query) -> SomeNamedTuple: """Get next row from query result as a named tuple.""" - return _row_factory(q.listfields())(next(q)) + return RowCache.row_factory(q.listfields())(next(q)) def scalariter(q: Query) -> Generator[Any, None, None]: diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 74ad38e5..b9a4449a 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -64,35 +64,7 @@ connection.close() # close the connection """ -from __future__ import annotations - -from collections import namedtuple -from collections.abc import Iterable -from contextlib import suppress -from datetime import date, datetime, time, timedelta, tzinfo -from decimal import Decimal as StdDecimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from re import compile as regex -from time import localtime -from typing import ( - Any, - Callable, - ClassVar, - Generator, - Mapping, - NamedTuple, - Sequence, - TypeVar, -) -from uuid import UUID as Uuid # noqa: N811 - -# import objects from extension module -from pg import ( - RESULT_DQL, +from pg.core import ( DatabaseError, DataError, Error, @@ -103,20 +75,50 @@ OperationalError, ProgrammingError, Warning, - cast_array, - cast_hstore, - cast_record, - unescape_bytea, version, ) -from pg import ( - Connection as Cnx, # base connection -) -from pg import ( - connect as get_cnx, # get base connection -) -__version__ = version +from .adapt import ( + ARRAY, + BINARY, + BOOL, + DATE, + DATETIME, + FLOAT, + HSTORE, + INTEGER, + INTERVAL, + JSON, + LONG, + MONEY, + NUMBER, + NUMERIC, + RECORD, + ROWID, + SMALLINT, + STRING, + TIME, + TIMESTAMP, + UUID, + Binary, + Date, + DateFromTicks, + DbType, + Hstore, + Interval, + Json, + Literal, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, + Uuid, +) +from .cast import get_typecast, reset_typecast, set_typecast +from .connect import connect +from .connection import Connection +from .constants import apilevel, paramstyle, shortcutmethods, threadsafety +from .cursor import Cursor __all__ = [ 'Connection', 'Cursor', @@ -131,1707 +133,9 @@ 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - 'apilevel', 'connect', 'paramstyle', 'threadsafety', 'get_typecast', 'set_typecast', 'reset_typecast', + 'apilevel', 'connect', 'paramstyle', 'shortcutmethods', 'threadsafety', 'version', '__version__', ] -Decimal: type = StdDecimal - - -# *** Module Constants *** - -# compliant with DB API 2.0 -apilevel = '2.0' - -# module may be shared, but not connections -threadsafety = 1 - -# this module use extended python format codes -paramstyle = 'pyformat' - -# shortcut methods have been excluded from DB API 2 and -# are not recommended by the DB SIG, but they can be handy -shortcutmethods = 1 - - -# *** Internal Type Handling *** - -def get_args(func: Callable) -> list: - return list(signature(func).parameters) - - -# 
time zones used in Postgres timestamptz output -_timezones: dict[str, str] = { - 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', - 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', - 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' -} - - -def _timezone_as_offset(tz: str) -> str: - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def decimal_type(decimal_type: type | None = None) -> type: - """Get or set global type to be used for decimal values. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - global Decimal - if decimal_type is not None: - Decimal = decimal_type - set_typecast('numeric', decimal_type) - return Decimal - - -def cast_bool(value: str) -> bool | None: - """Cast boolean value in database format to bool.""" - return value[0] in ('t', 'T') if value else None - - -def cast_money(value: str) -> StdDecimal | None: - """Cast money value in database format to Decimal.""" - if not value: - return None - value = value.replace('(', '-') - return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) - - -def cast_int2vector(value: str) -> list[int]: - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value: str, cnx: Cnx) -> date: - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. - if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - values = value.split() - if values[-1] == 'BC': - return date.min - value = values[0] - if len(value) > 10: - return date.max - format = cnx.date_format() - return datetime.strptime(value, format).date() - - -def cast_time(value: str) -> time: - """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value: str) -> time: - """Cast a timetz value.""" - m = _re_timezone.match(value) - if m: - value, tz = m.groups() - else: - tz = '+0000' - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - value += _timezone_as_offset(tz) - format += '%z' - return datetime.strptime(value, format).timetz() - - -def cast_timestamp(value: str, cnx: Cnx) -> datetime: - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = cnx.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:5] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -def cast_timestamptz(value: str, cnx: Cnx) -> datetime: - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return 
datetime.min - format = cnx.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - values, tz = values[:-1], values[-1] - else: - if format.startswith('%Y-'): - m = _re_timezone.match(values[1]) - if m: - values[1], tz = m.groups() - else: - tz = '+0000' - else: - values, tz = values[:-1], values[-1] - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - values.append(_timezone_as_offset(tz)) - formats.append('%z') - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' - '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value: str) -> timedelta: - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - s = [v or '0' for v in m.groups()] - secs_ago = s.pop(5) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) - secs_ago = s.pop(5) == '-' - d = [-int(v) for v in s] if ago else [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - years_ago = s.pop(0) == '-' - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError(f'Cannot parse interval: {value}') - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. 
The - passed string will never be None since NULL values are already - handled before the cast function is called. - """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults: ClassVar[dict[str, Callable]] = { - 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, 'sql_identifier': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, - 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, - 'float4': float, 'float8': float, - 'numeric': Decimal, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': Uuid, - 'anyarray': cast_array, 'record': cast_record} - - cnx: Cnx | None = None # for local connection specific instances - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. - """ - if not isinstance(typ, str): - raise TypeError(f'Invalid type: {typ}') - cast = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - # create array cast - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - # store only if base type exists - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func: Callable) -> bool: - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - return 'cnx' in args[1:] - - def _add_connection(self, cast: Callable) -> Callable: - """Add a connection argument to the typecast function if necessary.""" - if not self.cnx or not self._needs_connection(cast): - return cast - return partial(cast, cnx=self.cnx) - - def get(self, typ: str, default: Callable | None = None # type: ignore - ) -> Callable | None: - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, str): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop(f'_{t}', None) - - def reset(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. 
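A small usage sketch of this registry through the module-level functions defined further below (the numeric example is illustrative):

    import pgdb

    # parse PostgreSQL numeric values as float instead of Decimal
    pgdb.set_typecast('numeric', float)
    # running connections cache cast functions, so call
    # con.type_cache.reset_typecast() to make them pick up the change
    pgdb.reset_typecast('numeric')  # later: restore the default cast
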
- """ - defaults = self.defaults - if typ is None: - self.clear() - self.update(defaults) - else: - if isinstance(typ, str): - typ = [typ] - for t in typ: - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - t = f'_{t}' - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - else: - self.pop(t, None) - else: - self.pop(t, None) - self.pop(f'_{t}', None) - - def create_array_cast(self, basecast: Callable) -> Callable: - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - - def cast(v: Any) -> list: - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name: str, fields: Sequence[str], - casts: Sequence[str]) -> Callable: - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) # type: ignore - - def cast(v: Any) -> record: - # noinspection PyArgumentList - return record(*cast_record(v, casts)) - return cast - - -_typecasts = Typecasts() # this is the global typecast dictionary - - -def get_typecast(typ: str) -> Callable | None: - """Get the global typecast function for the given database type.""" - return _typecasts.get(typ) - - -def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - _typecasts.set(typ, cast) - - -def reset_typecast(typ: str | Sequence[str] | None = None) -> None: - """Reset the global typecasts for the given type(s) to their default. - - When no type is specified, all typecasts will be reset. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - _typecasts.reset(typ) - - -class LocalTypecasts(Typecasts): - """Map typecasts, including local composite types, to cast functions.""" - - defaults = _typecasts - - cnx: Cnx | None = None # set in connection specific instances - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached.""" - cast: Callable | None - if typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - cast = self.defaults.get(typ) - if cast: - cast = self._add_connection(cast) - self[typ] = cast - else: - fields = self.get_fields(typ) - if fields: - casts = [self[field.type] for field in fields] - field_names = [field.name for field in fields] - cast = self.create_record_cast(typ, field_names, casts) - self[typ] = cast - return cast - - # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_fields(self, typ: str) -> list[FieldInfo]: - """Return the fields for the given record type. - - This method will be replaced with a method that looks up the fields - using the type cache of the connection. - """ - return [] - - -class TypeCode(str): - """Class representing the type_code used by the DB-API 2.0. - - TypeCode objects are strings equal to the PostgreSQL type name, - but carry some additional information. 
- """ - - oid: int - len: int - type: str - category: str - delim: str - relid: int - - # noinspection PyShadowingBuiltins - @classmethod - def create(cls, oid: int, name: str, len: int, type: str, category: str, - delim: str, relid: int) -> TypeCode: - """Create a type code for a PostgreSQL data type.""" - self = cls(name) - self.oid = oid - self.len = len - self.type = type - self.category = category - self.delim = delim - self.relid = relid - return self - - -FieldInfo = namedtuple('FieldInfo', ('name', 'type')) - - -class TypeCache(dict): - """Cache for database types. - - This cache maps type OIDs and names to TypeCode strings containing - important information on the associated database type. - """ - - def __init__(self, cnx: Cnx) -> None: - """Initialize type cache for connection.""" - super().__init__() - self._escape_string = cnx.escape_string - self._src = cnx.source() - self._typecasts = LocalTypecasts() - self._typecasts.get_fields = self.get_fields # type: ignore - self._typecasts.cnx = cnx - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") - - def __missing__(self, key: int | str) -> TypeCode: - """Get the type info from the database if it is not cached.""" - oid: int | str - if isinstance(key, int): - oid = key - else: - if '.' not in key and '"' not in key: - key = f'"{key}"' - oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" - try: - self._src.execute(self._query_pg_type.format(oid)) - except ProgrammingError: - res = None - else: - res = self._src.fetch(1) - if not res: - raise KeyError(f'Type {key} could not be found') - r = res[0] - type_code = TypeCode.create( - int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) - # noinspection PyUnresolvedReferences - self[type_code.oid] = self[str(type_code)] = type_code - return type_code - - def get(self, key: int | str, # type: ignore - default: TypeCode | None = None) -> TypeCode | None: - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: - """Get the names and types of the fields of composite types.""" - if isinstance(typ, TypeCode): - relid = typ.relid - else: - type_code = self.get(typ) - if not type_code: - return None - relid = type_code.relid - if not relid: - return None # this type is not composite - self._src.execute( - "SELECT attname, atttypid" # noqa: S608 - " FROM pg_catalog.pg_attribute" - f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" - " AND attnum OPERATOR(pg_catalog.>) 0" - " AND NOT attisdropped ORDER BY attnum") - return [FieldInfo(name, self.get(int(oid))) - for name, oid in self._src.fetch(-1)] - - def get_typecast(self, typ: str) -> Callable | None: - """Get the typecast function for the given database type.""" - return self._typecasts[typ] - - def set_typecast(self, typ: str | Sequence[str], - cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value: Any, typ: str) -> Any: - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - cast = self._typecasts[typ] - if cast 
is None or cast is str: - # no typecast is necessary - return value - return cast(value) - - def get_row_caster(self, types: Sequence[str]) -> Callable: - """Get a typecast function for a complete row of values.""" - typecasts = self._typecasts - casts = [typecasts[typ] for typ in types] - casts = [cast if cast is not str else None for cast in casts] - - def row_caster(row: Sequence) -> Sequence: - return [value if cast is None or value is None else cast(value) - for cast, value in zip(casts, row)] - - return row_caster - - -class _QuoteDict(dict): - """Dictionary with auto quoting of its items. - - The quote attribute must be set to the desired quote function. - """ - - quote: Callable[[str], str] - - def __getitem__(self, key: str) -> str: - # noinspection PyUnresolvedReferences - return self.quote(super().__getitem__(key)) - - -# *** Error Messages *** - -E = TypeVar('E', bound=Error) - - -def _error(msg: str, cls: type[E]) -> E: - """Return specified error object with empty sqlstate attribute.""" - error = cls(msg) - if isinstance(error, DatabaseError): - error.sqlstate = None - return error - - -def _db_error(msg: str) -> DatabaseError: - """Return DatabaseError.""" - return _error(msg, DatabaseError) - - -def _if_error(msg: str) -> InterfaceError: - """Return InterfaceError.""" - return _error(msg, InterfaceError) - - -def _op_error(msg: str) -> OperationalError: - """Return OperationalError.""" - return _error(msg, OperationalError) - - -# *** Row Tuples *** - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. - -# noinspection PyUnresolvedReferences -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore - - -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. 
- """ - # noinspection PyGlobalUndefined - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -# *** Cursor Object *** - -class Cursor: - """Cursor object.""" - - def __init__(self, connection: Connection) -> None: - """Create a cursor object for the database connection.""" - self.connection = self._connection = connection - cnx = connection._cnx - if not cnx: - raise _op_error("Connection has been closed") - self._cnx: Cnx = cnx - self.type_cache: TypeCache = connection.type_cache - self._src = self._cnx.source() - # the official attribute for describing the result columns - self._description: list[CursorDescription] | bool | None = None - if self.row_factory is Cursor.row_factory: - # the row factory needs to be determined dynamically - self.row_factory = None # type: ignore - else: - self.build_row_factory = None # type: ignore - self.rowcount: int | None = -1 - self.arraysize: int = 1 - self.lastrowid: int | None = None - - def __iter__(self) -> Cursor: - """Make cursor compatible to the iteration protocol.""" - return self - - def __enter__(self) -> Cursor: - """Enter the runtime context for the cursor object.""" - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context for the cursor object.""" - self.close() - - def _quote(self, value: Any) -> Any: - """Quote value depending on its type.""" - if value is None: - return 'NULL' - if isinstance(value, (Hstore, Json)): - value = str(value) - if isinstance(value, (bytes, str)): - cnx = self._cnx - if isinstance(value, Binary): - value = cnx.escape_bytea(value).decode('ascii') - else: - value = cnx.escape_string(value) - return f"'{value}'" - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, Decimal, Literal)): - return value - if isinstance(value, datetime): - if value.tzinfo: - return f"'{value}'::timestamptz" - return f"'{value}'::timestamp" - if isinstance(value, date): - return f"'{value}'::date" - if isinstance(value, time): - if value.tzinfo: - return f"'{value}'::timetz" - return f"'{value}'::time" - if isinstance(value, timedelta): - return f"'{value}'::interval" - if isinstance(value, Uuid): - return f"'{value}'::uuid" - if isinstance(value, list): - # Quote value as an ARRAY constructor. This is better than using - # an array literal because it carries the information that this is - # an array and not a string. One issue with this syntax is that - # you need to add an explicit typecast when passing empty arrays. - # The ARRAY keyword is actually only necessary at the top level. - if not value: # exception for empty array - return "'{}'" - q = self._quote - v = ','.join(str(q(v)) for v in value) - return f'ARRAY[{v}]' - if isinstance(value, tuple): - # Quote as a ROW constructor. This is better than using a record - # literal because it carries the information that this is a record - # and not a string. We don't use the keyword ROW in order to make - # this usable with the IN syntax as well. It is only necessary - # when the records has a single column which is not really useful. 
- q = self._quote - v = ','.join(str(q(v)) for v in value) - return f'({v})' - try: # noinspection PyUnresolvedReferences - value = value.__pg_repr__() - except AttributeError as e: - raise InterfaceError( - f'Do not know how to adapt type {type(value)}') from e - if isinstance(value, (tuple, list)): - value = self._quote(value) - return value - - def _quoteparams(self, string: str, - parameters: Mapping | Sequence | None) -> str: - """Quote parameters. - - This function works for both mappings and sequences. - - The function should be used even when there are no parameters, - so that we have a consistent behavior regarding percent signs. - """ - if not parameters: - try: - return string % () # unescape literal quotes if possible - except (TypeError, ValueError): - return string # silently accept unescaped quotes - if isinstance(parameters, dict): - parameters = _QuoteDict(parameters) - parameters.quote = self._quote - else: - parameters = tuple(map(self._quote, parameters)) - return string % parameters - - def _make_description(self, info: tuple[int, str, int, int, int] - ) -> CursorDescription: - """Make the description tuple for the given field info.""" - name, typ, size, mod = info[1:] - type_code = self.type_cache[typ] - if mod > 0: - mod -= 4 - precision: int | None - scale: int | None - if type_code == 'numeric': - precision, scale = mod >> 16, mod & 0xffff - size = precision - else: - if not size: - size = type_code.size - if size == -1: - size = mod - precision = scale = None - return CursorDescription( - name, type_code, None, size, precision, scale, None) - - @property - def description(self) -> list[CursorDescription] | None: - """Read-only attribute describing the result columns.""" - description = self._description - if description is None: - return None - if not isinstance(description, list): - make = self._make_description - description = [make(info) for info in self._src.listinfo()] - self._description = description - return description - - @property - def colnames(self) -> Sequence[str] | None: - """Unofficial convenience method for getting the column names.""" - description = self.description - return None if description is None else [d[0] for d in description] - - @property - def coltypes(self) -> Sequence[TypeCode] | None: - """Unofficial convenience method for getting the column types.""" - description = self.description - return None if description is None else [d[1] for d in description] - - def close(self) -> None: - """Close the cursor object.""" - self._src.close() - - def execute(self, operation: str, parameters: Sequence | None = None - ) -> Cursor: - """Prepare and execute a database operation (query or command).""" - # The parameters may also be specified as list of tuples to e.g. - # insert multiple rows in a single operation, but this kind of - # usage is deprecated. We make several plausibility checks because - # tuples can also be passed with the meaning of ROW constructors. 
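These plausibility checks make the dispatch subtle; a sketch of how the two cases are told apart (statements are illustrative):

    # a list holding ONE tuple: the tuple is quoted as a ROW constructor
    cur.execute("insert into t values (%s)", [(1, 'a')])

    # several same-length tuples: treated as one parameter tuple per row
    # and dispatched to executemany() (deprecated, better call it directly)
    cur.execute("insert into t values (%s, %s)", [(1, 'a'), (2, 'b')])
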
- if (parameters and isinstance(parameters, list) - and len(parameters) > 1 - and all(isinstance(p, tuple) for p in parameters) - and all(len(p) == len(parameters[0]) for p in parameters[1:])): - return self.executemany(operation, parameters) - # not a list of tuples - return self.executemany(operation, [parameters]) - - def executemany(self, operation: str, - seq_of_parameters: Sequence[Sequence | None]) -> Cursor: - """Prepare operation and execute it against a parameter sequence.""" - if not seq_of_parameters: - # don't do anything without parameters - return self - self._description = None - self.rowcount = -1 - # first try to execute all queries - rowcount = 0 - sql = "BEGIN" - try: - if not self._connection._tnx and not self._connection.autocommit: - try: - self._src.execute(sql) - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't start transaction") from e - else: - self._connection._tnx = True - for parameters in seq_of_parameters: - sql = operation - sql = self._quoteparams(sql, parameters) - rows = self._src.execute(sql) - if rows: # true if not DML - rowcount += rows - else: - self.rowcount = -1 - except DatabaseError: - raise # database provides error message - except Error as err: - # noinspection PyTypeChecker - raise _if_error(f"Error in '{sql}': '{err}'") from err - except Exception as err: - raise _op_error(f"Internal error in '{sql}': {err}") from err - # then initialize result raw count and description - if self._src.resulttype == RESULT_DQL: - self._description = True # fetch on demand - self.rowcount = self._src.ntuples - self.lastrowid = None - build_row_factory = self.build_row_factory - if build_row_factory: # type: ignore - self.row_factory = build_row_factory() # type: ignore - else: - self.rowcount = rowcount - self.lastrowid = self._src.oidstatus() - # return the cursor object, so you can write statements such as - # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" - return self - - def fetchone(self) -> Sequence | None: - """Fetch the next row of a query result set.""" - res = self.fetchmany(1, False) - try: - return res[0] - except IndexError: - return None - - def fetchall(self) -> Sequence[Sequence]: - """Fetch all (remaining) rows of a query result.""" - return self.fetchmany(-1, False) - - def fetchmany(self, size: int | None = None, keep: bool = False - ) -> Sequence[Sequence]: - """Fetch the next set of rows of a query result. - - The number of rows to fetch per call is specified by the - size parameter. If it is not given, the cursor's arraysize - determines the number of rows to be fetched. If you set - the keep parameter to true, this is kept as new arraysize. 
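A short sketch of batched fetching with the keep flag described above (table and handler are illustrative):

    cur.execute("select * from big_table")
    while True:
        rows = cur.fetchmany(500, keep=True)  # also sets cur.arraysize to 500
        if not rows:
            break
        handle(rows)
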
- """ - if size is None: - size = self.arraysize - if keep: - self.arraysize = size - try: - result = self._src.fetch(size) - except DatabaseError: - raise - except Error as err: - raise _db_error(str(err)) from err - row_factory = self.row_factory - coltypes = self.coltypes - if coltypes is None: - # cannot determine column types, return raw result - return [row_factory(row) for row in result] - if len(result) > 5: - # optimize the case where we really fetch many values - # by looking up all type casting functions upfront - cast_row = self.type_cache.get_row_caster(coltypes) - return [row_factory(cast_row(row)) for row in result] - cast_value = self.type_cache.typecast - return [row_factory([cast_value(value, typ) - for typ, value in zip(coltypes, row)]) for row in result] - - def callproc(self, procname: str, parameters: Sequence | None = None - ) -> Sequence | None: - """Call a stored database procedure with the given name. - - The sequence of parameters must contain one entry for each input - argument that the procedure expects. The result of the call is the - same as this input sequence; replacement of output and input/output - parameters in the return value is currently not supported. - - The procedure may also provide a result set as output. These can be - requested through the standard fetch methods of the cursor. - """ - n = len(parameters) if parameters else 0 - s = ','.join(n * ['%s']) - query = f'select * from "{procname}"({s})' # noqa: S608 - self.execute(query, parameters) - return parameters - - # noinspection PyShadowingBuiltins - def copy_from(self, stream: Any, table: str, - format: str | None = None, sep: str | None = None, - null: str | None = None, size: int | None = None, - columns: Sequence[str] | None = None) -> Cursor: - """Copy data from an input stream to the specified table. - - The input stream can be a file-like object with a read() method or - it can also be an iterable returning a row or multiple rows of input - on each iteration. - - The format must be 'text', 'csv' or 'binary'. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the input. - - The size option sets the size of the buffer used when reading data - from file-like objects. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - try: - read = stream.read - except AttributeError as e: - if size: - raise ValueError( - "Size must only be set for file-like objects") from e - input_type: type | tuple[type, ...] 
- type_name: str - if binary_format: - input_type = bytes - type_name = 'byte strings' - else: - input_type = (bytes, str) - type_name = 'strings' - - if isinstance(stream, (bytes, str)): - if not isinstance(stream, input_type): - raise ValueError(f"The input must be {type_name}") from e - if not binary_format: - if isinstance(stream, str): - if not stream.endswith('\n'): - stream += '\n' - else: - if not stream.endswith(b'\n'): - stream += b'\n' - - def chunks() -> Generator: - yield stream - - elif isinstance(stream, Iterable): - - def chunks() -> Generator: - for chunk in stream: - if not isinstance(chunk, input_type): - raise ValueError( - f"Input stream must consist of {type_name}") - if isinstance(chunk, str): - if not chunk.endswith('\n'): - chunk += '\n' - else: - if not chunk.endswith(b'\n'): - chunk += b'\n' - yield chunk - - else: - raise TypeError("Need an input stream to copy from") from e - else: - if size is None: - size = 8192 - elif not isinstance(size, int): - raise TypeError("The size option must be an integer") - if size > 0: - - def chunks() -> Generator: - while True: - buffer = read(size) - yield buffer - if not buffer or len(buffer) < size: - break - - else: - - def chunks() -> Generator: - yield read() - - if not table or not isinstance(table, str): - raise TypeError("Need a table to copy to") - if table.lower().startswith('select '): - raise ValueError("Must specify a table, not a query") - cnx = self._cnx - table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) - operation_parts = [f'copy {table}'] - options = [] - parameters = [] - if format is not None: - if not isinstance(format, str): - raise TypeError("The format option must be be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append(f'format {format}') - if sep is not None: - if not isinstance(sep, str): - raise TypeError("The sep option must be a string") - if format == 'binary': - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - parameters.append(sep) - if null is not None: - if not isinstance(null, str): - raise TypeError("The null option must be a string") - options.append('null %s') - parameters.append(null) - if columns: - if not isinstance(columns, str): - columns = ','.join(map(cnx.escape_identifier, columns)) - operation_parts.append(f'({columns})') - operation_parts.append("from stdin") - if options: - operation_parts.append(f"({','.join(options)})") - operation = ' '.join(operation_parts) - - putdata = self._src.putdata - self.execute(operation, parameters) - - try: - for chunk in chunks(): - putdata(chunk) - except BaseException as error: - self.rowcount = -1 - # the following call will re-raise the error - putdata(error) - else: - rowcount = putdata(None) - self.rowcount = -1 if rowcount is None else rowcount - - # return the cursor object, so you can chain operations - return self - - # noinspection PyShadowingBuiltins - def copy_to(self, stream: Any, table: str, - format: str | None = None, sep: str | None = None, - null: str | None = None, decode: bool | None = None, - columns: Sequence[str] | None = None) -> Cursor | Generator: - """Copy data from the specified table to an output stream. - - The output stream can be a file-like object with a write() method or - it can also be None, in which case the method will return a generator - yielding a row on each iteration. 
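Passing None as the stream thus turns copy_to() into a generator; a sketch (the table name is illustrative):

    # stream a table as decoded CSV lines, one row per iteration
    for line in cur.copy_to(None, 'mytable', format='csv'):
        print(line, end='')  # for non-binary formats, rows keep their newline
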
- - Output will be returned as byte strings unless you set decode to true. - - Note that you can also use a select query instead of the table name. - - The format must be 'text', 'csv' or 'binary'. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the output. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - if stream is None: - write = None - else: - try: - write = stream.write - except AttributeError as e: - raise TypeError("Need an output stream to copy to") from e - if not table or not isinstance(table, str): - raise TypeError("Need a table to copy to") - cnx = self._cnx - if table.lower().startswith('select '): - if columns: - raise ValueError("Columns must be specified in the query") - table = f'({table})' - else: - table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) - operation_parts = [f'copy {table}'] - options = [] - parameters = [] - if format is not None: - if not isinstance(format, str): - raise TypeError("The format option must be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append(f'format {format}') - if sep is not None: - if not isinstance(sep, str): - raise TypeError("The sep option must be a string") - if binary_format: - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - parameters.append(sep) - if null is not None: - if not isinstance(null, str): - raise TypeError("The null option must be a string") - options.append('null %s') - parameters.append(null) - if decode is None: - decode = format != 'binary' - else: - if not isinstance(decode, (int, bool)): - raise TypeError("The decode option must be a boolean") - if decode and binary_format: - raise ValueError( - "The decode option is not allowed with binary format") - if columns: - if not isinstance(columns, str): - columns = ','.join(map(cnx.escape_identifier, columns)) - operation_parts.append(f'({columns})') - - operation_parts.append("to stdout") - if options: - operation_parts.append(f"({','.join(options)})") - operation = ' '.join(operation_parts) - - getdata = self._src.getdata - self.execute(operation, parameters) - - def copy() -> Generator: - self.rowcount = 0 - while True: - row = getdata(decode) - if isinstance(row, int): - if self.rowcount != row: - self.rowcount = row - break - self.rowcount += 1 - yield row - - if write is None: - # no input stream, return the generator - return copy() - - # write the rows to the file-like input stream - for row in copy(): - # noinspection PyUnboundLocalVariable - write(row) - - # return the cursor object, so you can chain operations - return self - - def __next__(self) -> Sequence: - """Return the next row (support for the iteration protocol).""" - res = self.fetchone() - if res is None: - raise StopIteration - return res - - # Note that the iterator protocol now uses __next()__ instead of next(), - # but we keep it for backward compatibility of pgdb. 
- next = __next__ - - @staticmethod - def nextset() -> bool | None: - """Not supported.""" - raise NotSupportedError("The nextset() method is not supported") - - @staticmethod - def setinputsizes(sizes: Sequence[int]) -> None: - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def setoutputsize(size: int, column: int = 0) -> None: - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def row_factory(row: Sequence) -> Sequence: - """Process rows before they are returned. - - You can overwrite this statically with a custom row factory, or - you can build a row factory dynamically with build_row_factory(). - - For example, you can create a Cursor class that returns rows as - Python dictionaries like this: - - class DictCursor(pgdb.Cursor): - - def row_factory(self, row): - return {desc[0]: value - for desc, value in zip(self.description, row)} - - cur = DictCursor(con) # get one DictCursor instance or - con.cursor_type = DictCursor # always use DictCursor instances - """ - raise NotImplementedError - - def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: - """Build a row factory based on the current description. - - This implementation builds a row factory for creating named tuples. - You can overwrite this method if you want to dynamically create - different row factories whenever the column description changes. - """ - names = self.colnames - return _row_factory(tuple(names)) if names else None - - -CursorDescription = namedtuple('CursorDescription', ( - 'name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok')) - - -# *** Connection Objects *** - -class Connection: - """Connection object.""" - - # expose the exceptions as attributes on the connection object - Error = Error - Warning = Warning - InterfaceError = InterfaceError - DatabaseError = DatabaseError - InternalError = InternalError - OperationalError = OperationalError - ProgrammingError = ProgrammingError - IntegrityError = IntegrityError - DataError = DataError - NotSupportedError = NotSupportedError - - def __init__(self, cnx: Cnx) -> None: - """Create a database connection object.""" - self._cnx: Cnx | None = cnx # connection - self._tnx = False # transaction state - self.type_cache = TypeCache(cnx) - self.cursor_type = Cursor - self.autocommit = False - try: - self._cnx.source() - except Exception as e: - raise _op_error("Invalid connection") from e - - def __enter__(self) -> Connection: - """Enter the runtime context for the connection object. - - The runtime context can be used for running transactions. - - This also starts a transaction in autocommit mode. - """ - if self.autocommit: - cnx = self._cnx - if not cnx: - raise _op_error("Connection has been closed") - try: - cnx.source().execute("BEGIN") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't start transaction") from e - else: - self._tnx = True - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context for the connection object. - - This does not close the connection, but it ends a transaction. 
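So the context manager brackets a transaction, not the connection's lifetime; a minimal sketch (connection parameters are illustrative):

    con = pgdb.connect(database='test')
    with con:  # in autocommit mode this explicitly begins a transaction
        con.cursor().execute("insert into t values (%s)", (42,))
    # leaving the block commits, or rolls back if an exception was raised;
    # the connection itself is still open here
    con.close()
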
- """ - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def close(self) -> None: - """Close the connection object.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - with suppress(DatabaseError): - self.rollback() - self._cnx.close() - self._cnx = None - - @property - def closed(self) -> bool: - """Check whether the connection has been closed or is broken.""" - try: - return not self._cnx or self._cnx.status != 1 - except TypeError: - return True - - def commit(self) -> None: - """Commit any pending transaction to the database.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("COMMIT") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't commit transaction") from e - - def rollback(self) -> None: - """Roll back to the start of any pending transaction.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("ROLLBACK") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't rollback transaction") from e - - def cursor(self) -> Cursor: - """Return a new cursor object using the connection.""" - if not self._cnx: - raise _op_error("Connection has been closed") - try: - return self.cursor_type(self) - except Exception as e: - raise _op_error("Invalid connection") from e - - if shortcutmethods: # otherwise do not implement and document this - - def execute(self, operation: str, - parameters: Sequence | None = None) -> Cursor: - """Shortcut method to run an operation on an implicit cursor.""" - cursor = self.cursor() - cursor.execute(operation, parameters) - return cursor - - def executemany(self, operation: str, - seq_of_parameters: Sequence[Sequence | None] - ) -> Cursor: - """Shortcut method to run an operation against a sequence.""" - cursor = self.cursor() - cursor.executemany(operation, seq_of_parameters) - return cursor - - -# *** Module Interface *** - -def connect(dsn: str | None = None, - user: str | None = None, password: str | None = None, - host: str | None = None, database: str | None = None, - **kwargs: Any) -> Connection: - """Connect to a database.""" - # first get params from DSN - dbport = -1 - dbhost: str | None = "" - dbname: str | None = "" - dbuser: str | None = "" - dbpasswd: str | None = "" - dbopt: str | None = "" - if dsn: - try: - params = dsn.split(":", 4) - dbhost = params[0] - dbname = params[1] - dbuser = params[2] - dbpasswd = params[3] - dbopt = params[4] - except (AttributeError, IndexError, TypeError): - pass - - # override if necessary - if user is not None: - dbuser = user - if password is not None: - dbpasswd = password - if database is not None: - dbname = database - if host: - try: - params = host.split(":", 1) - dbhost = params[0] - dbport = int(params[1]) - except (AttributeError, IndexError, TypeError, ValueError): - pass - - # empty host is localhost - if dbhost == "": - dbhost = None - if dbuser == "": - dbuser = None - - # pass keyword arguments as connection info string - if kwargs: - kwarg_list = list(kwargs.items()) - kw_parts = [] - if dbname and '=' in dbname: - kw_parts.append(dbname) - else: - kwarg_list.insert(0, ('dbname', dbname)) - for kw, value in kwarg_list: - value = str(value) - if not value or ' ' in value: - value = value.replace('\\', '\\\\').replace("'", 
"\\'") - value = f"'{value}'" - kw_parts.append(f'{kw}={value}') - dbname = ' '.join(kw_parts) - # open the connection - cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) - return Connection(cnx) - - -# *** Types Handling *** - -class DbType(frozenset): - """Type class for a couple of PostgreSQL data types. - - PostgreSQL is object-oriented: types are dynamic. - We must thus use type names as internal type codes. - """ - - def __new__(cls, values: str | Iterable[str]) -> DbType: - """Create new type object.""" - if isinstance(values, str): - values = values.split() - return super().__new__(cls, values) # type: ignore - - def __eq__(self, other: Any) -> bool: - """Check whether types are considered equal.""" - if isinstance(other, str): - if other.startswith('_'): - other = other[1:] - return other in self - return super().__eq__(other) - - def __ne__(self, other: Any) -> bool: - """Check whether types are not considered equal.""" - if isinstance(other, str): - if other.startswith('_'): - other = other[1:] - return other not in self - return super().__ne__(other) - - -class ArrayType: - """Type class for PostgreSQL array types.""" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, str): - return other.startswith('_') - return isinstance(other, ArrayType) - - def __ne__(self, other: Any) -> bool: - if isinstance(other, str): - return not other.startswith('_') - return not isinstance(other, ArrayType) - - -class RecordType: - """Type class for PostgreSQL record types.""" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, TypeCode): - # noinspection PyUnresolvedReferences - return other.type == 'c' - if isinstance(other, str): - return other == 'record' - return isinstance(other, RecordType) - - def __ne__(self, other: Any) -> bool: - if isinstance(other, TypeCode): - # noinspection PyUnresolvedReferences - return other.type != 'c' - if isinstance(other, str): - return other != 'record' - return not isinstance(other, RecordType) - - -# Mandatory type objects defined by DB-API 2 specs: - -STRING = DbType('char bpchar name text varchar') -BINARY = DbType('bytea') -NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') -DATETIME = DbType('date time timetz timestamp timestamptz interval' - ' abstime reltime') # these are very old -ROWID = DbType('oid') - - -# Additional type objects (more specific): - -BOOL = DbType('bool') -SMALLINT = DbType('int2') -INTEGER = DbType('int2 int4 int8 serial') -LONG = DbType('int8') -FLOAT = DbType('float4 float8') -NUMERIC = DbType('numeric') -MONEY = DbType('money') -DATE = DbType('date') -TIME = DbType('time timetz') -TIMESTAMP = DbType('timestamp timestamptz') -INTERVAL = DbType('interval') -UUID = DbType('uuid') -HSTORE = DbType('hstore') -JSON = DbType('json jsonb') - -# Type object for arrays (also equate to their base types): - -ARRAY = ArrayType() - -# Type object for records (encompassing all composite types): - -RECORD = RecordType() - - -# Mandatory type helpers defined by DB-API 2 specs: - -def Date(year: int, month: int, day: int) -> date: # noqa: N802 - """Construct an object holding a date value.""" - return date(year, month, day) - - -def Time(hour: int, minute: int = 0, # noqa: N802 - second: int = 0, microsecond: int = 0, - tzinfo: tzinfo | None = None) -> time: - """Construct an object holding a time value.""" - return time(hour, minute, second, microsecond, tzinfo) - - -def Timestamp(year: int, month: int, day: int, # noqa: N802 - hour: int = 0, minute: int = 0, - second: int = 0, 
microsecond: int = 0, - tzinfo: tzinfo | None = None) -> datetime: - """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, - second, microsecond, tzinfo) - - -def DateFromTicks(ticks: float | None) -> date: # noqa: N802 - """Construct an object holding a date value from the given ticks value.""" - return Date(*localtime(ticks)[:3]) - - -def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 - """Construct an object holding a time value from the given ticks value.""" - return Time(*localtime(ticks)[3:6]) - - -def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 - """Construct an object holding a time stamp from the given ticks value.""" - return Timestamp(*localtime(ticks)[:6]) - - -class Binary(bytes): - """Construct an object capable of holding a binary (long) string value.""" - - -# Additional type helpers for PyGreSQL: - -def Interval(days: int | float, # noqa: N802 - hours: int | float = 0, minutes: int | float = 0, - seconds: int | float = 0, microseconds: int | float = 0 - ) -> timedelta: - """Construct an object holding a time interval value.""" - return timedelta(days, hours=hours, minutes=minutes, - seconds=seconds, microseconds=microseconds) - - -Uuid = Uuid # Construct an object holding a UUID value - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - _re_escape = regex(r'(["\\])') - - @classmethod - def _quote(cls, s: Any) -> Any: - if s is None: - return 'NULL' - if not isinstance(s, str): - s = str(s) - if not s: - return '""' - quote = cls._re_quote.search(s) - s = cls._re_escape.sub(r'\\\1', s) - if quote: - s = f'"{s}"' - return s - - def __str__(self) -> str: - """Create a printable representation of the hstore value.""" - q = self._quote - return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) - - -class Json: - """Construct a wrapper for holding an object serializable to JSON.""" - - def __init__(self, obj: Any, - encode: Callable[[Any], str] | None = None) -> None: - """Initialize the JSON object.""" - self.obj = obj - self.encode = encode or jsonencode - - def __str__(self) -> str: - """Create a printable representation of the JSON object.""" - obj = self.obj - if isinstance(obj, str): - return obj - return self.encode(obj) - - -class Literal: - """Construct a wrapper for holding a literal SQL string.""" - - def __init__(self, sql: str) -> None: - """Initialize literal SQL string.""" - self.sql = sql - - def __str__(self) -> str: - """Return a printable representation of the SQL string.""" - return self.sql - - __pg_repr__ = __str__ - - -# If run as script, print some information: - -if __name__ == '__main__': - print('PyGreSQL version', version) - print() - print(__doc__) +__version__ = version diff --git a/pgdb/adapt.py b/pgdb/adapt.py new file mode 100644 index 00000000..92b48a7e --- /dev/null +++ b/pgdb/adapt.py @@ -0,0 +1,237 @@ +"""Type helpers for adaptation of parameters.""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, tzinfo +from json import dumps as jsonencode +from re import compile as regex +from time import localtime +from typing import Any, Callable, Iterable +from uuid import UUID as Uuid # noqa: N811 + +from .typecode import TypeCode + +__all__ = [ + 'DbType', 'ArrayType', 'RecordType', + 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', 'SMALLINT', + 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', + 'TIMESTAMP', 'INTERVAL', 
'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', + 'Date', 'Time', 'Timestamp', + 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks' + +] + + +class DbType(frozenset): + """Type class for a couple of PostgreSQL data types. + + PostgreSQL is object-oriented: types are dynamic. + We must thus use type names as internal type codes. + """ + + def __new__(cls, values: str | Iterable[str]) -> DbType: + """Create new type object.""" + if isinstance(values, str): + values = values.split() + return super().__new__(cls, values) # type: ignore + + def __eq__(self, other: Any) -> bool: + """Check whether types are considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other in self + return super().__eq__(other) + + def __ne__(self, other: Any) -> bool: + """Check whether types are not considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other not in self + return super().__ne__(other) + + +class ArrayType: + """Type class for PostgreSQL array types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether arrays are equal.""" + if isinstance(other, str): + return other.startswith('_') + return isinstance(other, ArrayType) + + def __ne__(self, other: Any) -> bool: + """Check whether arrays are different.""" + if isinstance(other, str): + return not other.startswith('_') + return not isinstance(other, ArrayType) + + +class RecordType: + """Type class for PostgreSQL record types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether records are equal.""" + if isinstance(other, TypeCode): + return other.type == 'c' + if isinstance(other, str): + return other == 'record' + return isinstance(other, RecordType) + + def __ne__(self, other: Any) -> bool: + """Check whether records are different.""" + if isinstance(other, TypeCode): + return other.type != 'c' + if isinstance(other, str): + return other != 'record' + return not isinstance(other, RecordType) + + +# Mandatory type objects defined by DB-API 2 specs: + +STRING = DbType('char bpchar name text varchar') +BINARY = DbType('bytea') +NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') +DATETIME = DbType('date time timetz timestamp timestamptz interval' + ' abstime reltime') # these are very old +ROWID = DbType('oid') + + +# Additional type objects (more specific): + +BOOL = DbType('bool') +SMALLINT = DbType('int2') +INTEGER = DbType('int2 int4 int8 serial') +LONG = DbType('int8') +FLOAT = DbType('float4 float8') +NUMERIC = DbType('numeric') +MONEY = DbType('money') +DATE = DbType('date') +TIME = DbType('time timetz') +TIMESTAMP = DbType('timestamp timestamptz') +INTERVAL = DbType('interval') +UUID = DbType('uuid') +HSTORE = DbType('hstore') +JSON = DbType('json jsonb') + +# Type object for arrays (also equate to their base types): + +ARRAY = ArrayType() + +# Type object for records (encompassing all composite types): + +RECORD = RecordType() + + +# Mandatory type helpers defined by DB-API 2 specs: + +def Date(year: int, month: int, day: int) -> date: # noqa: N802 + """Construct an object holding a date value.""" + return date(year, month, day) + + +def Time(hour: int, minute: int = 0, # noqa: N802 + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> time: + """Construct an object holding a time value.""" + return time(hour, minute, second, microsecond, tzinfo) + + +def Timestamp(year: int, month: int, day: int, # noqa: N802 + hour: int = 0, minute: int = 0, + second: int = 0, microsecond: int = 
0, + tzinfo: tzinfo | None = None) -> datetime: + """Construct an object holding a time stamp value.""" + return datetime(year, month, day, hour, minute, + second, microsecond, tzinfo) + + +def DateFromTicks(ticks: float | None) -> date: # noqa: N802 + """Construct an object holding a date value from the given ticks value.""" + return Date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 + """Construct an object holding a time value from the given ticks value.""" + return Time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 + """Construct an object holding a time stamp from the given ticks value.""" + return Timestamp(*localtime(ticks)[:6]) + + +class Binary(bytes): + """Construct an object capable of holding a binary (long) string value.""" + + +# Additional type helpers for PyGreSQL: + +def Interval(days: int | float, # noqa: N802 + hours: int | float = 0, minutes: int | float = 0, + seconds: int | float = 0, microseconds: int | float = 0 + ) -> timedelta: + """Construct an object holding a time interval value.""" + return timedelta(days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds) + + +Uuid = Uuid # Construct an object holding a UUID value + + +class Hstore(dict): + """Wrapper class for marking hstore values.""" + + _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') + _re_escape = regex(r'(["\\])') + + @classmethod + def _quote(cls, s: Any) -> Any: + if s is None: + return 'NULL' + if not isinstance(s, str): + s = str(s) + if not s: + return '""' + quote = cls._re_quote.search(s) + s = cls._re_escape.sub(r'\\\1', s) + if quote: + s = f'"{s}"' + return s + + def __str__(self) -> str: + """Create a printable representation of the hstore value.""" + q = self._quote + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) + + +class Json: + """Construct a wrapper for holding an object serializable to JSON.""" + + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: + """Initialize the JSON object.""" + self.obj = obj + self.encode = encode or jsonencode + + def __str__(self) -> str: + """Create a printable representation of the JSON object.""" + obj = self.obj + if isinstance(obj, str): + return obj + return self.encode(obj) + + +class Literal: + """Construct a wrapper for holding a literal SQL string.""" + + def __init__(self, sql: str) -> None: + """Initialize literal SQL string.""" + self.sql = sql + + def __str__(self) -> str: + """Return a printable representation of the SQL string.""" + return self.sql + + __pg_repr__ = __str__ \ No newline at end of file diff --git a/pgdb/cast.py b/pgdb/cast.py new file mode 100644 index 00000000..03367506 --- /dev/null +++ b/pgdb/cast.py @@ -0,0 +1,581 @@ +"""Internal type handling.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, time, timedelta +from decimal import Decimal as _Decimal +from functools import partial +from inspect import signature +from json import loads as jsondecode +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import Connection as Cnx +from pg.core import ( + ProgrammingError, + cast_array, + cast_hstore, + cast_record, + unescape_bytea, +) + +from .typecode import TypeCode + +__all__ = [ + 'Decimal', 'decimal_type', 'cast_bool', 'cast_money', + 'cast_int2vector', 'cast_date', 'cast_time', 'cast_interval', + 
'cast_timetz', 'cast_timestamp', 'cast_timestamptz', + 'get_typecast', 'set_typecast', 'reset_typecast', + 'Typecasts', 'LocalTypecasts', 'TypeCache', 'FieldInfo' +] + + +Decimal: type = _Decimal + + +def get_args(func: Callable) -> list: + return list(signature(func).parameters) + + +# time zones used in Postgres timestamptz output +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} + + +def _timezone_as_offset(tz: str) -> str: + if tz.startswith(('+', '-')): + if len(tz) < 5: + return tz + '00' + return tz.replace(':', '') + return _timezones.get(tz, '+0000') + + +def decimal_type(decimal_type: type | None = None) -> type: + """Get or set global type to be used for decimal values. + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + global Decimal + if decimal_type is not None: + Decimal = decimal_type + set_typecast('numeric', decimal_type) + return Decimal + + +def cast_bool(value: str) -> bool | None: + """Cast boolean value in database format to bool.""" + return value[0] in ('t', 'T') if value else None + + +def cast_money(value: str) -> _Decimal | None: + """Cast money value in database format to Decimal.""" + if not value: + return None + value = value.replace('(', '-') + return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, cnx: Cnx) -> date: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. 
+ if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = cnx.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> time: + """Cast a time value.""" + fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, fmt).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> time: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += _timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, cnx: Cnx) -> datetime: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, cnx: Cnx) -> datetime: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' 
+ '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. The + passed string will never be None since NULL values are already + handled before the cast function is called. + """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, + 'float4': float, 'float8': float, + 'numeric': Decimal, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': Uuid, + 'anyarray': cast_array, 'record': cast_record} + + cnx: Cnx | None = None # for local connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. 
+ """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + # create array cast + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + # store only if base type exists + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'cnx' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.cnx or not self._needs_connection(cast): + return cast + return partial(cast, cnx=self.cnx) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. + """ + defaults = self.defaults + if typ is None: + self.clear() + self.update(defaults) + else: + if isinstance(typ, str): + typ = [typ] + for t in typ: + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + t = f'_{t}' + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + else: + self.pop(t, None) + else: + self.pop(t, None) + self.pop(f'_{t}', None) + + def create_array_cast(self, basecast: Callable) -> Callable: + """Create an array typecast for the given base cast.""" + cast_array = self['anyarray'] + + def cast(v: Any) -> list: + return cast_array(v, basecast) + return cast + + def create_record_cast(self, name: str, fields: Sequence[str], + casts: Sequence[str]) -> Callable: + """Create a named record typecast for the given fields and casts.""" + cast_record = self['record'] + record = namedtuple(name, fields) # type: ignore + + def cast(v: Any) -> record: + # noinspection PyArgumentList + return record(*cast_record(v, casts)) + return cast + + +_typecasts = Typecasts() # this is the global typecast dictionary + + +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" + return _typecasts.get(typ) + + +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a global typecast function for the given database type(s). + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.set(typ, cast) + + +def reset_typecast(typ: str | Sequence[str] | None = None) -> None: + """Reset the global typecasts for the given type(s) to their default. + + When no type is specified, all typecasts will be reset. + + Note that connections cache cast functions. 
To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.reset(typ) + + +class LocalTypecasts(Typecasts): + """Map typecasts, including local composite types, to cast functions.""" + + defaults = _typecasts + + cnx: Cnx | None = None # set in connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached.""" + cast: Callable | None + if typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + cast = self.defaults.get(typ) + if cast: + cast = self._add_connection(cast) + self[typ] = cast + else: + fields = self.get_fields(typ) + if fields: + casts = [self[field.type] for field in fields] + field_names = [field.name for field in fields] + cast = self.create_record_cast(typ, field_names, casts) + self[typ] = cast + return cast + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + def get_fields(self, typ: str) -> list[FieldInfo]: + """Return the fields for the given record type. + + This method will be replaced with a method that looks up the fields + using the type cache of the connection. + """ + return [] + + +FieldInfo = namedtuple('FieldInfo', ('name', 'type')) + + +class TypeCache(dict): + """Cache for database types. + + This cache maps type OIDs and names to TypeCode strings containing + important information on the associated database type. + """ + + def __init__(self, cnx: Cnx) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._escape_string = cnx.escape_string + self._src = cnx.source() + self._typecasts = LocalTypecasts() + self._typecasts.get_fields = self.get_fields # type: ignore + self._typecasts.cnx = cnx + self._query_pg_type = ( + "SELECT oid, typname," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") + + def __missing__(self, key: int | str) -> TypeCode: + """Get the type info from the database if it is not cached.""" + oid: int | str + if isinstance(key, int): + oid = key + else: + if '.' 
not in key and '"' not in key: + key = f'"{key}"' + oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" + try: + self._src.execute(self._query_pg_type.format(oid)) + except ProgrammingError: + res = None + else: + res = self._src.fetch(1) + if not res: + raise KeyError(f'Type {key} could not be found') + r = res[0] + type_code = TypeCode.create( + int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) + # noinspection PyUnresolvedReferences + self[type_code.oid] = self[str(type_code)] = type_code + return type_code + + def get(self, key: int | str, # type: ignore + default: TypeCode | None = None) -> TypeCode | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: + """Get the names and types of the fields of composite types.""" + if isinstance(typ, TypeCode): + relid = typ.relid + else: + type_code = self.get(typ) + if not type_code: + return None + relid = type_code.relid + if not relid: + return None # this type is not composite + self._src.execute( + "SELECT attname, atttypid" # noqa: S608 + " FROM pg_catalog.pg_attribute" + f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" + " AND attnum OPERATOR(pg_catalog.>) 0" + " AND NOT attisdropped ORDER BY attnum") + return [FieldInfo(name, self.get(int(oid))) + for name, oid in self._src.fetch(-1)] + + def get_typecast(self, typ: str) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts[typ] + + def set_typecast(self, typ: str | Sequence[str], + cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is necessary + return None + cast = self._typecasts[typ] + if cast is None or cast is str: + # no typecast is necessary + return value + return cast(value) + + def get_row_caster(self, types: Sequence[str]) -> Callable: + """Get a typecast function for a complete row of values.""" + typecasts = self._typecasts + casts = [typecasts[typ] for typ in types] + casts = [cast if cast is not str else None for cast in casts] + + def row_caster(row: Sequence) -> Sequence: + return [value if cast is None or value is None else cast(value) + for cast, value in zip(casts, row)] + + return row_caster \ No newline at end of file diff --git a/pgdb/connect.py b/pgdb/connect.py new file mode 100644 index 00000000..73b96a36 --- /dev/null +++ b/pgdb/connect.py @@ -0,0 +1,74 @@ +"""The DB API 2 connect function.""" + +from __future__ import annotations + +from typing import Any + +from pg.core import connect as get_cnx + +from .connection import Connection + +__all__ = ['connect'] + +def connect(dsn: str | None = None, + user: str | None = None, password: str | None = None, + host: str | None = None, database: str | None = None, + **kwargs: Any) -> Connection: + """Connect to a database.""" + # first get params from DSN + dbport = -1 + dbhost: str | None = "" + dbname: str | None = "" + dbuser: str | None = "" + dbpasswd: str | None = "" + dbopt: str | None = "" + if dsn: + try: + params = dsn.split(":", 4) + dbhost = params[0] + dbname = params[1] 
+ dbuser = params[2] + dbpasswd = params[3] + dbopt = params[4] + except (AttributeError, IndexError, TypeError): + pass + + # override if necessary + if user is not None: + dbuser = user + if password is not None: + dbpasswd = password + if database is not None: + dbname = database + if host: + try: + params = host.split(":", 1) + dbhost = params[0] + dbport = int(params[1]) + except (AttributeError, IndexError, TypeError, ValueError): + pass + + # empty host is localhost + if dbhost == "": + dbhost = None + if dbuser == "": + dbuser = None + + # pass keyword arguments as connection info string + if kwargs: + kwarg_list = list(kwargs.items()) + kw_parts = [] + if dbname and '=' in dbname: + kw_parts.append(dbname) + else: + kwarg_list.insert(0, ('dbname', dbname)) + for kw, value in kwarg_list: + value = str(value) + if not value or ' ' in value: + value = value.replace('\\', '\\\\').replace("'", "\\'") + value = f"'{value}'" + kw_parts.append(f'{kw}={value}') + dbname = ' '.join(kw_parts) + # open the connection + cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) + return Connection(cnx) diff --git a/pgdb/connection.py b/pgdb/connection.py new file mode 100644 index 00000000..17d32bcc --- /dev/null +++ b/pgdb/connection.py @@ -0,0 +1,156 @@ +"""The DB API 2 Connection objects.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any, Sequence + +from pg.core import Connection as Cnx +from pg.core import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) +from pg.error import op_error + +from .cast import TypeCache +from .constants import shortcutmethods +from .cursor import Cursor + +__all__ = ['Connection'] + +class Connection: + """Connection object.""" + + # expose the exceptions as attributes on the connection object + Error = Error + Warning = Warning + InterfaceError = InterfaceError + DatabaseError = DatabaseError + InternalError = InternalError + OperationalError = OperationalError + ProgrammingError = ProgrammingError + IntegrityError = IntegrityError + DataError = DataError + NotSupportedError = NotSupportedError + + def __init__(self, cnx: Cnx) -> None: + """Create a database connection object.""" + self._cnx: Cnx | None = cnx # connection + self._tnx = False # transaction state + self.type_cache = TypeCache(cnx) + self.cursor_type = Cursor + self.autocommit = False + try: + self._cnx.source() + except Exception as e: + raise op_error("Invalid connection") from e + + def __enter__(self) -> Connection: + """Enter the runtime context for the connection object. + + The runtime context can be used for running transactions. + + This also starts a transaction in autocommit mode. + """ + if self.autocommit: + cnx = self._cnx + if not cnx: + raise op_error("Connection has been closed") + try: + cnx.source().execute("BEGIN") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._tnx = True + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the connection object. + + This does not close the connection, but it ends a transaction. 
+ """ + if et is None and ev is None and tb is None: + self.commit() + else: + self.rollback() + + def close(self) -> None: + """Close the connection object.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + with suppress(DatabaseError): + self.rollback() + self._cnx.close() + self._cnx = None + + @property + def closed(self) -> bool: + """Check whether the connection has been closed or is broken.""" + try: + return not self._cnx or self._cnx.status != 1 + except TypeError: + return True + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("COMMIT") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't commit transaction") from e + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("ROLLBACK") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't rollback transaction") from e + + def cursor(self) -> Cursor: + """Return a new cursor object using the connection.""" + if not self._cnx: + raise op_error("Connection has been closed") + try: + return self.cursor_type(self) + except Exception as e: + raise op_error("Invalid connection") from e + + if shortcutmethods: # otherwise do not implement and document this + + def execute(self, operation: str, + parameters: Sequence | None = None) -> Cursor: + """Shortcut method to run an operation on an implicit cursor.""" + cursor = self.cursor() + cursor.execute(operation, parameters) + return cursor + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None] + ) -> Cursor: + """Shortcut method to run an operation against a sequence.""" + cursor = self.cursor() + cursor.executemany(operation, seq_of_parameters) + return cursor \ No newline at end of file diff --git a/pgdb/constants.py b/pgdb/constants.py new file mode 100644 index 00000000..e6547f9c --- /dev/null +++ b/pgdb/constants.py @@ -0,0 +1,14 @@ +"""The DB API 2 module constants.""" + +# compliant with DB API 2.0 +apilevel = '2.0' + +# module may be shared, but not connections +threadsafety = 1 + +# this module use extended python format codes +paramstyle = 'pyformat' + +# shortcut methods have been excluded from DB API 2 and +# are not recommended by the DB SIG, but they can be handy +shortcutmethods = 1 diff --git a/pgdb/cursor.py b/pgdb/cursor.py new file mode 100644 index 00000000..753f4691 --- /dev/null +++ b/pgdb/cursor.py @@ -0,0 +1,645 @@ +"""The DB API 2 Cursor object.""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from math import isinf, isnan +from typing import TYPE_CHECKING, Any, Callable, Generator, Mapping, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import ( + RESULT_DQL, + DatabaseError, + Error, + InterfaceError, + NotSupportedError, +) +from pg.core import Connection as Cnx +from pg.error import db_error, if_error, op_error +from pg.helpers import QuoteDict, RowCache + +from .adapt import Binary, Hstore, Json, Literal +from .cast import TypeCache +from .typecode import 
TypeCode + +if TYPE_CHECKING: + from .connection import Connection + +__all__ = ['Cursor', 'CursorDescription'] + + +class Cursor: + """Cursor object.""" + + def __init__(self, connection: Connection) -> None: + """Create a cursor object for the database connection.""" + self.connection = self._connection = connection + cnx = connection._cnx + if not cnx: + raise op_error("Connection has been closed") + self._cnx: Cnx = cnx + self.type_cache: TypeCache = connection.type_cache + self._src = self._cnx.source() + # the official attribute for describing the result columns + self._description: list[CursorDescription] | bool | None = None + if self.row_factory is Cursor.row_factory: + # the row factory needs to be determined dynamically + self.row_factory = None # type: ignore + else: + self.build_row_factory = None # type: ignore + self.rowcount: int | None = -1 + self.arraysize: int = 1 + self.lastrowid: int | None = None + + def __iter__(self) -> Cursor: + """Make cursor compatible to the iteration protocol.""" + return self + + def __enter__(self) -> Cursor: + """Enter the runtime context for the cursor object.""" + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the cursor object.""" + self.close() + + def _quote(self, value: Any) -> Any: + """Quote value depending on its type.""" + if value is None: + return 'NULL' + if isinstance(value, (Hstore, Json)): + value = str(value) + if isinstance(value, (bytes, str)): + cnx = self._cnx + if isinstance(value, Binary): + value = cnx.escape_bytea(value).decode('ascii') + else: + value = cnx.escape_string(value) + return f"'{value}'" + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal, Literal)): + return value + if isinstance(value, datetime): + if value.tzinfo: + return f"'{value}'::timestamptz" + return f"'{value}'::timestamp" + if isinstance(value, date): + return f"'{value}'::date" + if isinstance(value, time): + if value.tzinfo: + return f"'{value}'::timetz" + return f"'{value}'::time" + if isinstance(value, timedelta): + return f"'{value}'::interval" + if isinstance(value, Uuid): + return f"'{value}'::uuid" + if isinstance(value, list): + # Quote value as an ARRAY constructor. This is better than using + # an array literal because it carries the information that this is + # an array and not a string. One issue with this syntax is that + # you need to add an explicit typecast when passing empty arrays. + # The ARRAY keyword is actually only necessary at the top level. + if not value: # exception for empty array + return "'{}'" + q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'ARRAY[{v}]' + if isinstance(value, tuple): + # Quote as a ROW constructor. This is better than using a record + # literal because it carries the information that this is a record + # and not a string. We don't use the keyword ROW in order to make + # this usable with the IN syntax as well. It is only necessary + # when the records has a single column which is not really useful. 
+ q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'({v})' + try: # noinspection PyUnresolvedReferences + value = value.__pg_repr__() + except AttributeError as e: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') from e + if isinstance(value, (tuple, list)): + value = self._quote(value) + return value + + def _quoteparams(self, string: str, + parameters: Mapping | Sequence | None) -> str: + """Quote parameters. + + This function works for both mappings and sequences. + + The function should be used even when there are no parameters, + so that we have a consistent behavior regarding percent signs. + """ + if not parameters: + try: + return string % () # unescape literal quotes if possible + except (TypeError, ValueError): + return string # silently accept unescaped quotes + if isinstance(parameters, dict): + parameters = QuoteDict(parameters) + parameters.quote = self._quote + else: + parameters = tuple(map(self._quote, parameters)) + return string % parameters + + def _make_description(self, info: tuple[int, str, int, int, int] + ) -> CursorDescription: + """Make the description tuple for the given field info.""" + name, typ, size, mod = info[1:] + type_code = self.type_cache[typ] + if mod > 0: + mod -= 4 + precision: int | None + scale: int | None + if type_code == 'numeric': + precision, scale = mod >> 16, mod & 0xffff + size = precision + else: + if not size: + size = type_code.size + if size == -1: + size = mod + precision = scale = None + return CursorDescription( + name, type_code, None, size, precision, scale, None) + + @property + def description(self) -> list[CursorDescription] | None: + """Read-only attribute describing the result columns.""" + description = self._description + if description is None: + return None + if not isinstance(description, list): + make = self._make_description + description = [make(info) for info in self._src.listinfo()] + self._description = description + return description + + @property + def colnames(self) -> Sequence[str] | None: + """Unofficial convenience method for getting the column names.""" + description = self.description + return None if description is None else [d[0] for d in description] + + @property + def coltypes(self) -> Sequence[TypeCode] | None: + """Unofficial convenience method for getting the column types.""" + description = self.description + return None if description is None else [d[1] for d in description] + + def close(self) -> None: + """Close the cursor object.""" + self._src.close() + + def execute(self, operation: str, parameters: Sequence | None = None + ) -> Cursor: + """Prepare and execute a database operation (query or command).""" + # The parameters may also be specified as list of tuples to e.g. + # insert multiple rows in a single operation, but this kind of + # usage is deprecated. We make several plausibility checks because + # tuples can also be passed with the meaning of ROW constructors. 
+ if (parameters and isinstance(parameters, list) + and len(parameters) > 1 + and all(isinstance(p, tuple) for p in parameters) + and all(len(p) == len(parameters[0]) for p in parameters[1:])): + return self.executemany(operation, parameters) + # not a list of tuples + return self.executemany(operation, [parameters]) + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None]) -> Cursor: + """Prepare operation and execute it against a parameter sequence.""" + if not seq_of_parameters: + # don't do anything without parameters + return self + self._description = None + self.rowcount = -1 + # first try to execute all queries + rowcount = 0 + sql = "BEGIN" + try: + if not self._connection._tnx and not self._connection.autocommit: + try: + self._src.execute(sql) + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._connection._tnx = True + for parameters in seq_of_parameters: + sql = operation + sql = self._quoteparams(sql, parameters) + rows = self._src.execute(sql) + if rows: # true if not DML + rowcount += rows + else: + self.rowcount = -1 + except DatabaseError: + raise # database provides error message + except Error as err: + # noinspection PyTypeChecker + raise if_error(f"Error in '{sql}': '{err}'") from err + except Exception as err: + raise op_error(f"Internal error in '{sql}': {err}") from err + # then initialize result raw count and description + if self._src.resulttype == RESULT_DQL: + self._description = True # fetch on demand + self.rowcount = self._src.ntuples + self.lastrowid = None + build_row_factory = self.build_row_factory + if build_row_factory: # type: ignore + self.row_factory = build_row_factory() # type: ignore + else: + self.rowcount = rowcount + self.lastrowid = self._src.oidstatus() + # return the cursor object, so you can write statements such as + # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" + return self + + def fetchone(self) -> Sequence | None: + """Fetch the next row of a query result set.""" + res = self.fetchmany(1, False) + try: + return res[0] + except IndexError: + return None + + def fetchall(self) -> Sequence[Sequence]: + """Fetch all (remaining) rows of a query result.""" + return self.fetchmany(-1, False) + + def fetchmany(self, size: int | None = None, keep: bool = False + ) -> Sequence[Sequence]: + """Fetch the next set of rows of a query result. + + The number of rows to fetch per call is specified by the + size parameter. If it is not given, the cursor's arraysize + determines the number of rows to be fetched. If you set + the keep parameter to true, this is kept as new arraysize. 
+ """ + if size is None: + size = self.arraysize + if keep: + self.arraysize = size + try: + result = self._src.fetch(size) + except DatabaseError: + raise + except Error as err: + raise db_error(str(err)) from err + row_factory = self.row_factory + coltypes = self.coltypes + if coltypes is None: + # cannot determine column types, return raw result + return [row_factory(row) for row in result] + if len(result) > 5: + # optimize the case where we really fetch many values + # by looking up all type casting functions upfront + cast_row = self.type_cache.get_row_caster(coltypes) + return [row_factory(cast_row(row)) for row in result] + cast_value = self.type_cache.typecast + return [row_factory([cast_value(value, typ) + for typ, value in zip(coltypes, row)]) for row in result] + + def callproc(self, procname: str, parameters: Sequence | None = None + ) -> Sequence | None: + """Call a stored database procedure with the given name. + + The sequence of parameters must contain one entry for each input + argument that the procedure expects. The result of the call is the + same as this input sequence; replacement of output and input/output + parameters in the return value is currently not supported. + + The procedure may also provide a result set as output. These can be + requested through the standard fetch methods of the cursor. + """ + n = len(parameters) if parameters else 0 + s = ','.join(n * ['%s']) + query = f'select * from "{procname}"({s})' # noqa: S608 + self.execute(query, parameters) + return parameters + + # noinspection PyShadowingBuiltins + def copy_from(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, size: int | None = None, + columns: Sequence[str] | None = None) -> Cursor: + """Copy data from an input stream to the specified table. + + The input stream can be a file-like object with a read() method or + it can also be an iterable returning a row or multiple rows of input + on each iteration. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the input. + + The size option sets the size of the buffer used when reading data + from file-like objects. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + try: + read = stream.read + except AttributeError as e: + if size: + raise ValueError( + "Size must only be set for file-like objects") from e + input_type: type | tuple[type, ...] 
+ type_name: str + if binary_format: + input_type = bytes + type_name = 'byte strings' + else: + input_type = (bytes, str) + type_name = 'strings' + + if isinstance(stream, (bytes, str)): + if not isinstance(stream, input_type): + raise ValueError(f"The input must be {type_name}") from e + if not binary_format: + if isinstance(stream, str): + if not stream.endswith('\n'): + stream += '\n' + else: + if not stream.endswith(b'\n'): + stream += b'\n' + + def chunks() -> Generator: + yield stream + + elif isinstance(stream, Iterable): + + def chunks() -> Generator: + for chunk in stream: + if not isinstance(chunk, input_type): + raise ValueError( + f"Input stream must consist of {type_name}") + if isinstance(chunk, str): + if not chunk.endswith('\n'): + chunk += '\n' + else: + if not chunk.endswith(b'\n'): + chunk += b'\n' + yield chunk + + else: + raise TypeError("Need an input stream to copy from") from e + else: + if size is None: + size = 8192 + elif not isinstance(size, int): + raise TypeError("The size option must be an integer") + if size > 0: + + def chunks() -> Generator: + while True: + buffer = read(size) + yield buffer + if not buffer or len(buffer) < size: + break + + else: + + def chunks() -> Generator: + yield read() + + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + if table.lower().startswith('select '): + raise ValueError("Must specify a table, not a query") + cnx = self._cnx + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if format == 'binary': + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + operation_parts.append("from stdin") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + putdata = self._src.putdata + self.execute(operation, parameters) + + try: + for chunk in chunks(): + putdata(chunk) + except BaseException as error: + self.rowcount = -1 + # the following call will re-raise the error + putdata(error) + else: + rowcount = putdata(None) + self.rowcount = -1 if rowcount is None else rowcount + + # return the cursor object, so you can chain operations + return self + + # noinspection PyShadowingBuiltins + def copy_to(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, decode: bool | None = None, + columns: Sequence[str] | None = None) -> Cursor | Generator: + """Copy data from the specified table to an output stream. + + The output stream can be a file-like object with a write() method or + it can also be None, in which case the method will return a generator + yielding a row on each iteration. 
+ + Output will be returned as byte strings unless you set decode to true. + + Note that you can also use a select query instead of the table name. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the output. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + if stream is None: + write = None + else: + try: + write = stream.write + except AttributeError as e: + raise TypeError("Need an output stream to copy to") from e + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + cnx = self._cnx + if table.lower().startswith('select '): + if columns: + raise ValueError("Columns must be specified in the query") + table = f'({table})' + else: + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if binary_format: + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if decode is None: + decode = format != 'binary' + else: + if not isinstance(decode, (int, bool)): + raise TypeError("The decode option must be a boolean") + if decode and binary_format: + raise ValueError( + "The decode option is not allowed with binary format") + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + + operation_parts.append("to stdout") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + getdata = self._src.getdata + self.execute(operation, parameters) + + def copy() -> Generator: + self.rowcount = 0 + while True: + row = getdata(decode) + if isinstance(row, int): + if self.rowcount != row: + self.rowcount = row + break + self.rowcount += 1 + yield row + + if write is None: + # no input stream, return the generator + return copy() + + # write the rows to the file-like input stream + for row in copy(): + # noinspection PyUnboundLocalVariable + write(row) + + # return the cursor object, so you can chain operations + return self + + def __next__(self) -> Sequence: + """Return the next row (support for the iteration protocol).""" + res = self.fetchone() + if res is None: + raise StopIteration + return res + + # Note that the iterator protocol now uses __next()__ instead of next(), + # but we keep it for backward compatibility of pgdb. 
+ next = __next__ + + @staticmethod + def nextset() -> bool | None: + """Not supported.""" + raise NotSupportedError("The nextset() method is not supported") + + @staticmethod + def setinputsizes(sizes: Sequence[int]) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def setoutputsize(size: int, column: int = 0) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def row_factory(row: Sequence) -> Sequence: + """Process rows before they are returned. + + You can overwrite this statically with a custom row factory, or + you can build a row factory dynamically with build_row_factory(). + + For example, you can create a Cursor class that returns rows as + Python dictionaries like this: + + class DictCursor(pgdb.Cursor): + + def row_factory(self, row): + return {desc[0]: value + for desc, value in zip(self.description, row)} + + cur = DictCursor(con) # get one DictCursor instance or + con.cursor_type = DictCursor # always use DictCursor instances + """ + raise NotImplementedError + + def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: + """Build a row factory based on the current description. + + This implementation builds a row factory for creating named tuples. + You can overwrite this method if you want to dynamically create + different row factories whenever the column description changes. + """ + names = self.colnames + return RowCache.row_factory(tuple(names)) if names else None + + +CursorDescription = namedtuple('CursorDescription', ( + 'name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok')) diff --git a/pgdb/typecode.py b/pgdb/typecode.py new file mode 100644 index 00000000..fcfb4620 --- /dev/null +++ b/pgdb/typecode.py @@ -0,0 +1,34 @@ +"""Support for DB API 2 type codes.""" + +from __future__ import annotations + +__all__ = ['TypeCode'] + + +class TypeCode(str): + """Class representing the type_code used by the DB-API 2.0. + + TypeCode objects are strings equal to the PostgreSQL type name, + but carry some additional information. 
+ """ + + oid: int + len: int + type: str + category: str + delim: str + relid: int + + # noinspection PyShadowingBuiltins + @classmethod + def create(cls, oid: int, name: str, len: int, type: str, category: str, + delim: str, relid: int) -> TypeCode: + """Create a type code for a PostgreSQL data type.""" + self = cls(name) + self.oid = oid + self.len = len + self.type = type + self.category = category + self.delim = delim + self.relid = relid + return self \ No newline at end of file diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 0c038f72..bf3c5718 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -97,12 +97,13 @@ def tearDown(self): def _connect(self): try: - r = self.driver.connect( - *self.connect_args, **self.connect_kw_args - ) + con = self.driver.connect( + *self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") - return r + if not isinstance(con, self.driver.Connection): + self.fail("The connect method does not return a Connection") + return con def test_connect(self): con = self._connect() diff --git a/tests/test_classic.py b/tests/test_classic.py index a6f78197..3bf0fe5c 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -148,7 +148,7 @@ def test_sqlstate(self): try: db.query("INSERT INTO _test_schema VALUES (1234)") except DatabaseError as error: - self.assertTrue(isinstance(error, IntegrityError)) + self.assertIsInstance(error, IntegrityError) # the SQLSTATE error code for unique violation is 23505 # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '23505') @@ -238,7 +238,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'event_1') self.assertEqual(arg_dict['extra'], 'payload 1') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) arg_dict['called'] = False self.assertTrue(thread.is_alive()) @@ -257,7 +257,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'stop_event_1') self.assertEqual(arg_dict['extra'], 'payload 2') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) thread.join(5) self.assertFalse(thread.is_alive()) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index eca64afd..dcb7a5e2 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2250,7 +2250,7 @@ def test_get_notify(self): self.assertIsNone(self.c.getnotify()) query("notify test_notify, 'test_payload'") r = getnotify() - self.assertTrue(isinstance(r, tuple)) + self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertIsInstance(r[0], str) self.assertIsInstance(r[1], int) @@ -2636,11 +2636,12 @@ def test_set_bytea_escaped(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') - def test_set_row_factory_size(self): + def test_change_row_factory_cache_size(self): + cache = pg.RowCache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): - pg.set_row_factory_size(maxsize) + cache.change_size(maxsize) for _i in range(3): for q in queries: r = query(q).namedresult()[0] @@ -2650,12 +2651,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - from pg.helpers import 
_row_factory - info = _row_factory.cache_info() + info = cache.row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual(info.hits, + 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 2ddde601..0755d95e 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -3164,7 +3164,7 @@ def test_context_manager(self): query("insert into test_table values (6)") query("insert into test_table values (-1)") except pg.IntegrityError as error: - self.assertTrue('check' in str(error)) + self.assertIn('check', str(error)) with self.db: query("insert into test_table values (7)") r = [r[0] for r in query( @@ -3276,7 +3276,8 @@ def test_upsert_bytea(self): if pg.get_bytea_escaped(): self.assertNotEqual(data, s) self.assertIsInstance(data, str) - data = pg.unescape_bytea(data) # type: ignore + assert isinstance(data, str) # type guard + data = pg.unescape_bytea(data) self.assertIsInstance(data, bytes) self.assertEqual(data, s) d['data'] = None diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 2e731c6e..ef4857d3 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -5,6 +5,7 @@ import gc import unittest from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal from typing import Any, ClassVar from uuid import UUID as Uuid # noqa: N811 @@ -443,7 +444,6 @@ def test_cursor_invalidation(self): self.assertRaises(pgdb.OperationalError, cur.fetchone) def test_fetch_2_rows(self): - Decimal = pgdb.decimal_type() # noqa: N806 values = ('test', pgdb.Binary(b'\xff\x52\xb2'), True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), @@ -536,7 +536,7 @@ def test_sqlstate(self): try: cur.execute("select 1/0") except pgdb.DatabaseError as error: - self.assertTrue(isinstance(error, pgdb.DataError)) + self.assertIsInstance(error, pgdb.DataError) # the SQLSTATE error code for division by zero is 22012 # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') @@ -575,9 +575,9 @@ def test_float(self): if isinf(inval): # type: ignore self.assertTrue(isinf(outval)) if inval < 0: # type: ignore - self.assertTrue(outval < 0) + self.assertLess(outval, 0) else: - self.assertTrue(outval > 0) + self.assertGreater(outval, 0) elif isnan(inval): # type: ignore self.assertTrue(isnan(outval)) else: @@ -586,25 +586,27 @@ def test_float(self): def test_datetime(self): dt = datetime(2011, 7, 17, 15, 47, 42, 317509) values = [dt.date(), dt.time(), dt, dt.time(), dt] - assert isinstance(values[3], time) + self.assertIsInstance(values[3], time) + assert isinstance(values[3], time) # type guard values[3] = values[3].replace(tzinfo=timezone.utc) - assert isinstance(values[4], datetime) + self.assertIsInstance(values[4], datetime) + assert isinstance(values[4], datetime) # type guard values[4] = values[4].replace(tzinfo=timezone.utc) - d = (dt.year, dt.month, dt.day) - t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (timezone.utc,) + da = (dt.year, dt.month, dt.day) + ti = (dt.hour, dt.minute, dt.second, dt.microsecond) + tz = (timezone.utc,) inputs = [ # input as objects values, # input as text [v.isoformat() for v in values], # type: ignore # # input using type helpers - [pgdb.Date(*d), 
pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + [pgdb.Date(*da), pgdb.Time(*ti), + pgdb.Timestamp(*(da + ti)), pgdb.Time(*(ti + tz)), + pgdb.Timestamp(*(da + ti + tz))] ] table = self.table_prefix + 'booze' - con = self._connect() + con: pgdb.Connection = self._connect() try: cur = con.cursor() cur.execute("set timezone = UTC") @@ -624,7 +626,8 @@ def test_datetime(self): " values (%s,%s,%s,%s,%s)", params) cur.execute(f"select * from {table}") d = cur.description - assert isinstance(d, list) + self.assertIsInstance(d, list) + assert d is not None # type guard for i in range(5): tc = d[i].type_code self.assertEqual(tc, pgdb.DATETIME) @@ -855,8 +858,8 @@ def test_custom_type(self): con.close() def test_set_decimal_type(self): - decimal_type = pgdb.decimal_type() - self.assertTrue(decimal_type is not None and callable(decimal_type)) + from pgdb.cast import decimal_type + self.assertIs(decimal_type(), Decimal) con = self._connect() try: cur = con.cursor() @@ -870,19 +873,19 @@ def __init__(self, value: Any) -> None: def __str__(self) -> str: return str(self.value).replace('.', ',') - self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) + self.assertIs(decimal_type(CustomDecimal), CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] - self.assertTrue(isinstance(value, CustomDecimal)) + self.assertIsInstance(value, CustomDecimal) self.assertEqual(str(value), '4,25') # change decimal type again to float - self.assertTrue(pgdb.decimal_type(float) is float) + self.assertIs(decimal_type(float), float) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # the connection still uses the old setting - self.assertTrue(isinstance(value, str)) + self.assertIsInstance(value, str) self.assertEqual(str(value), '4,25') # bust the cache for type functions for the connection con.type_cache.reset_typecast() @@ -890,12 +893,12 @@ def __str__(self) -> str: self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # now the connection uses the new setting - self.assertTrue(isinstance(value, float)) + self.assertIsInstance(value, float) self.assertEqual(value, 4.25) finally: con.close() - pgdb.decimal_type(decimal_type) - self.assertTrue(pgdb.decimal_type() is decimal_type) + decimal_type(Decimal) + self.assertIs(decimal_type(), Decimal) def test_global_typecast(self): try: @@ -1272,7 +1275,7 @@ def test_connection_as_contextmanager(self): cur.execute(f"insert into {table} values (3)") cur.execute(f"insert into {table} values (4)") except con.IntegrityError as error: - self.assertTrue('check' in str(error).lower()) + self.assertIn('check', str(error).lower()) with con: cur.execute(f"insert into {table} values (5)") cur.execute(f"insert into {table} values (6)") @@ -1325,11 +1328,11 @@ def test_pgdb_type(self): self.assertEqual('int8', pgdb.INTEGER) self.assertNotEqual('int4', pgdb.LONG) self.assertEqual('int8', pgdb.LONG) - self.assertTrue('char' in pgdb.STRING) - self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER) - self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER) - self.assertTrue(pgdb.TIME <= pgdb.DATETIME) - self.assertTrue(pgdb.DATETIME >= pgdb.DATE) + self.assertIn('char', pgdb.STRING) + self.assertLess(pgdb.NUMERIC, pgdb.NUMBER) + self.assertGreaterEqual(pgdb.NUMBER, pgdb.INTEGER) + self.assertLessEqual(pgdb.TIME, pgdb.DATETIME) + self.assertGreaterEqual(pgdb.DATETIME, pgdb.DATE) 
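+        # These comparisons work because pgdb type objects are
+        # set-like collections of type names: equality with a string
+        # tests membership (so 'int8' equals both INTEGER and LONG
+        # above), and the ordering operators express subset relations
+        # (e.g. NUMERIC is a proper subset of NUMBER).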
self.assertEqual(pgdb.ARRAY, pgdb.ARRAY) self.assertNotEqual(pgdb.ARRAY, pgdb.STRING) self.assertEqual('_char', pgdb.ARRAY) @@ -1349,12 +1352,13 @@ def test_no_close(self): row = cur.fetchone() self.assertEqual(row, data) - def test_set_row_factory_size(self): + def test_change_row_factory_cache_size(self): + from pg import RowCache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() for maxsize in (None, 0, 1, 2, 3, 10, 1024): - pgdb.set_row_factory_size(maxsize) + RowCache.change_size(maxsize) for _i in range(3): for q in queries: cur.execute(q) @@ -1365,11 +1369,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - info = pgdb._row_factory.cache_info() + info = RowCache.row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual(info.hits, + 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): ids: set = set() From b08b775cf77bd2cb8275429897f0dcb5dcd61fe9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 11:59:52 +0200 Subject: [PATCH 159/194] Improve distribution files wording --- docs/download/files.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/download/files.rst b/docs/download/files.rst index f5e7a523..fc3ad26f 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,11 +3,11 @@ Distribution files ============== = -pg/ the "classic" PyGreSQL module +pg/ the "classic" PyGreSQL package pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -ext/ the source files for the C extension +ext/ the source files for the C extension module docs/ the documentation directory From c70b726984e2320883ac14451bcce2381324cc26 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:04:46 +0200 Subject: [PATCH 160/194] Improve typing of DB wrapper init method --- pg/db.py | 46 +++++++++++++++++++++++---------- tests/test_classic_dbwrapper.py | 7 ++--- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/pg/db.py b/pg/db.py index ce7915f8..03f7d7a4 100644 --- a/pg/db.py +++ b/pg/db.py @@ -6,7 +6,7 @@ from json import dumps as jsonencode from json import loads as jsondecode from operator import itemgetter -from typing import Any, Callable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, overload from . import Connection, connect from .adapt import Adapter, DbTypes @@ -23,6 +23,9 @@ from .helpers import namediter, oid_key, quote_if_unqualified from .notify import NotificationHandler +if TYPE_CHECKING: + from pgdb.connection import Connection as DbApi2Connection + __all__ = ['DB'] # The actual PostgreSQL database connection interface: @@ -33,33 +36,48 @@ class DB: db: Connection | None = None # invalid fallback for underlying connection _db_args: Any # either the connect args or the underlying connection - def __init__(self, *args: Any, **kw: Any) -> None: + @overload + def __init__(self, dbname: str | None = None, + host: str | None = None, port: int = -1, + opt: str | None = None, + user: str | None = None, passwd: str | None = None, + nowait: bool = False) -> None: + ... + + @overload + def __init__(self, db: Connection | DB | DbApi2Connection) -> None: + ... + + def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. 
You can pass either the connection parameters or an existing - _pg or pgdb connection. This allows you to use the methods - of the classic pg interface with a DB-API 2 pgdb connection. + pg or pgdb Connection. This allows you to use the methods + of the classic pg interface with a DB-API 2 pgdb Connection. """ - if not args and len(kw) == 1: + if kw: db = kw.get('db') - elif not kw and len(args) == 1: + if db is not None and (args or len(kw) > 1): + raise TypeError("Conflicting connection parameters") + elif len(args) == 1 and not isinstance(args[0], str): db = args[0] else: db = None if db: if isinstance(db, DB): - db = db.db + db = db.db # allow db to be a wrapped Connection else: with suppress(AttributeError): - # noinspection PyUnresolvedReferences - db = db._cnx - if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): + db = db._cnx # allow db to be a pgdb Connection + if not isinstance(db, Connection): + raise TypeError( + "The 'db' argument must be a valid database connection.") + self._db_args = db + self._closeable = False + else: db = connect(*args, **kw) self._db_args = args, kw self._closeable = True - else: - self._db_args = db - self._closeable = False self.db = db self.dbname = db.db self._regtypes = False @@ -97,7 +115,7 @@ def __init__(self, *args: Any, **kw: Any) -> None: self.debug: Any = None def __getattr__(self, name: str) -> Any: - """Get the specified attritbute of the connection.""" + """Get the specified attribute of the connection.""" # All undefined members are same as in underlying connection: if self.db: return getattr(self.db, name) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 0755d95e..e53617dd 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -301,12 +301,13 @@ def test_existing_connection(self): self.assertIsNone(db.db) db = pg.DB(self.db) self.assertEqual(self.db.db, db.db) + assert self.db.db is not None db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) def test_existing_db_api2_connection(self): - class DBApi2Con: + class FakeDbApi2Connection: def __init__(self, cnx): self._cnx = cnx @@ -314,8 +315,8 @@ def __init__(self, cnx): def close(self): self._cnx.close() - db2 = DBApi2Con(self.db.db) - db = pg.DB(db2) + db2 = FakeDbApi2Connection(self.db.db) + db = pg.DB(db2) # type: ignore self.assertEqual(self.db.db, db.db) db.close() self.assertIsNone(db.db) From 17c42afbf46196510a7eacb822acad16de82948a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:13:44 +0200 Subject: [PATCH 161/194] Use different docstrings for overloaded methods --- pg/db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pg/db.py b/pg/db.py index 03f7d7a4..d541ac54 100644 --- a/pg/db.py +++ b/pg/db.py @@ -42,10 +42,12 @@ def __init__(self, dbname: str | None = None, opt: str | None = None, user: str | None = None, passwd: str | None = None, nowait: bool = False) -> None: + """Create a new connection using the specified parameters.""" ... @overload def __init__(self, db: Connection | DB | DbApi2Connection) -> None: + """Create a connection wrapper based on an existing connection.""" ... 
def __init__(self, *args: Any, **kw: Any) -> None: From 7cc9c879581299042af725de1f31ef869429c7b8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:15:46 +0200 Subject: [PATCH 162/194] Actually overloaded methods shouldn't have docstrings --- pg/db.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pg/db.py b/pg/db.py index d541ac54..a13ea357 100644 --- a/pg/db.py +++ b/pg/db.py @@ -42,13 +42,11 @@ def __init__(self, dbname: str | None = None, opt: str | None = None, user: str | None = None, passwd: str | None = None, nowait: bool = False) -> None: - """Create a new connection using the specified parameters.""" - ... + ... # create a new connection using the specified parameters @overload def __init__(self, db: Connection | DB | DbApi2Connection) -> None: - """Create a connection wrapper based on an existing connection.""" - ... + ... # create a connection wrapper based on an existing connection def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. From cc781024cce02bc7fb732a387a9db7dc41631eeb Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 16:16:30 +0200 Subject: [PATCH 163/194] Add immediately wrapped methods These methods now also check that the underlying connection is still valid, and they allow proper typing and auto completion for wrapped connections. --- ext/pgmodule.c | 4 +- pg/_pg.pyi | 5 +- pg/core.py | 3 +- pg/db.py | 153 ++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 160 insertions(+), 5 deletions(-) diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 546c5cc5..761ae1b7 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -21,7 +21,7 @@ static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, - *MultipleResultsError, *Connection, *Query; + *MultipleResultsError, *Connection, *Query, *LargeObject; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -1310,6 +1310,8 @@ PyInit__pg(void) PyDict_SetItemString(dict, "Connection", Connection); Query = (PyObject *)&queryType; PyDict_SetItemString(dict, "Query", Query); + LargeObject = (PyObject *)&largeType; + PyDict_SetItemString(dict, "LargeObject", LargeObject); /* Make the version available */ s = PyUnicode_FromString(PyPgVersion); diff --git a/pg/_pg.pyi b/pg/_pg.pyi index 70f6e37e..b14bd5fc 100644 --- a/pg/_pg.pyi +++ b/pg/_pg.pyi @@ -4,7 +4,10 @@ from __future__ import annotations from typing import Any, Callable, Iterable, Sequence, TypeVar -AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +try: + AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +except TypeError: # Python < 3.10 + AnyStr = Any # type: ignore SomeNamedTuple = Any # alias for accessing arbitrary named tuples version: str diff --git a/pg/core.py b/pg/core.py index 3eb8f745..e20bdbd0 100644 --- a/pg/core.py +++ b/pg/core.py @@ -62,6 +62,7 @@ InterfaceError, InternalError, InvalidResultError, + LargeObject, MultipleResultsError, NoResultError, NotSupportedError, @@ -113,7 +114,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', + 'Connection', 'Query', 'LargeObject', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', diff --git a/pg/db.py b/pg/db.py index a13ea357..f824cc9d 100644 --- 
a/pg/db.py +++ b/pg/db.py @@ -6,13 +6,22 @@ from json import dumps as jsonencode from json import loads as jsondecode from operator import itemgetter -from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Sequence, + TypeVar, + overload, +) from . import Connection, connect from .adapt import Adapter, DbTypes from .attrs import AttrDict from .core import ( InternalError, + LargeObject, ProgrammingError, Query, get_bool, @@ -26,12 +35,32 @@ if TYPE_CHECKING: from pgdb.connection import Connection as DbApi2Connection +try: + AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +except TypeError: # Python < 3.10 + AnyStr = Any # type: ignore + __all__ = ['DB'] + # The actual PostgreSQL database connection interface: class DB: - """Wrapper class for the _pg connection type.""" + """Wrapper class for the core connection type.""" + + dbname: str + host: str + port: int + options: str + error: str + status: int + user : str + protocol_version: int + server_version: int + socket: int + backend_pid: int + ssl_in_use: bool + ssl_attributes: dict[str, str | None] db: Connection | None = None # invalid fallback for underlying connection _db_args: Any # either the connect args or the underlying connection @@ -1326,6 +1355,126 @@ def notification_handler(self, event: str, callback: Callable, return NotificationHandler(self, event, callback, arg_dict, timeout, stop_event) + # immediately wrapped methods + + def send_query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new asynchronous query object for this connection.""" + if args is None: + return self._valid_db.send_query(cmd) + return self._valid_db.send_query(cmd, args) + + def poll(self) -> int: + """Complete an asynchronous connection and get its state.""" + return self._valid_db.poll() + + def cancel(self) -> None: + """Abandon processing of current SQL command.""" + self._valid_db.cancel() + + def fileno(self) -> int: + """Get the socket used to connect to the database.""" + return self._valid_db.fileno() + + def get_cast_hook(self) -> Callable | None: + """Get the function that handles all external typecasting.""" + return self._valid_db.get_cast_hook() + + def set_cast_hook(self, hook: Callable | None) -> None: + """Set a function that will handle all external typecasting.""" + self._valid_db.set_cast_hook(hook) + + def get_notice_receiver(self) -> Callable | None: + """Get the current notice receiver.""" + return self._valid_db.get_notice_receiver() + + def set_notice_receiver(self, receiver: Callable | None) -> None: + """Set a custom notice receiver.""" + self._valid_db.set_notice_receiver(receiver) + + def getnotify(self) -> tuple[str, int, str] | None: + """Get the last notify from the server.""" + return self._valid_db.getnotify() + + def inserttable(self, table: str, values: Sequence[list|tuple], + columns: list[str] | tuple[str, ...] | None = None) -> int: + """Insert a Python iterable into a database table.""" + if columns is None: + return self._valid_db.inserttable(table, values) + return self._valid_db.inserttable(table, values, columns) + + def transaction(self) -> int: + """Get the current in-transaction status of the server. + + The status returned by this method can be TRANS_IDLE (currently idle), + TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a + valid transaction block), or TRANS_INERROR (idle, in a failed + transaction block). TRANS_UNKNOWN is reported if the connection is + bad. 
The status TRANS_ACTIVE is reported only when a query has been + sent to the server and not yet completed. + """ + return self._valid_db.transaction() + + def parameter(self, name: str) -> str | None: + """Look up a current parameter setting of the server.""" + return self._valid_db.parameter(name) + + + def date_format(self) -> str: + """Look up the date format currently being used by the database.""" + return self._valid_db.date_format() + + def escape_literal(self, s: AnyStr) -> AnyStr: + """Escape a literal constant for use within SQL.""" + return self._valid_db.escape_literal(s) + + def escape_identifier(self, s: AnyStr) -> AnyStr: + """Escape an identifier for use within SQL.""" + return self._valid_db.escape_identifier(s) + + def escape_string(self, s: AnyStr) -> AnyStr: + """Escape a string for use within SQL.""" + return self._valid_db.escape_string(s) + + def escape_bytea(self, s: AnyStr) -> AnyStr: + """Escape binary data for use within SQL as type 'bytea'.""" + return self._valid_db.escape_bytea(s) + + def putline(self, line: str) -> None: + """Write a line to the server socket.""" + self._valid_db.putline(line) + + def getline(self) -> str: + """Get a line from server socket.""" + return self._valid_db.getline() + + def endcopy(self) -> None: + """Synchronize client and server.""" + self._valid_db.endcopy() + + def set_non_blocking(self, nb: bool) -> None: + """Set the non-blocking mode of the connection.""" + self._valid_db.set_non_blocking(nb) + + def is_non_blocking(self) -> bool: + """Get the non-blocking mode of the connection.""" + return self._valid_db.is_non_blocking() + + def locreate(self, mode: int) -> LargeObject: + """Create a large object in the database. + + The valid values for 'mode' parameter are defined as the module level + constants INV_READ and INV_WRITE. 
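+
+        A minimal sketch of typical use (assuming a DB instance 'db';
+        illustrative only, not part of the method contract):
+
+            lo = db.locreate(INV_READ | INV_WRITE)
+            lo.open(INV_WRITE)
+            lo.write(b'data')
+            lo.close()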
+ """ + return self._valid_db.locreate(mode) + + def getlo(self, oid: int) -> LargeObject: + """Build a large object from given oid.""" + return self._valid_db.getlo(oid) + + def loimport(self, filename: str) -> LargeObject: + """Import a file to a large object.""" + return self._valid_db.loimport(filename) + class _MemoryQuery: """Class that embodies a given query result.""" From 20ce949bd8428eb416e85b45263dfd3f986865f7 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 16:55:20 +0200 Subject: [PATCH 164/194] Support Python 3.12 and PostgreSQL 16 --- .bumpversion.cfg | 2 +- .devcontainer/provision.sh | 7 ++++++- README.rst | 6 ++++++ docs/about.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 5 +++-- docs/contents/install.rst | 2 +- pyproject.toml | 3 ++- setup.py | 3 ++- tests/test_classic_connection.py | 3 ++- tests/test_classic_dbwrapper.py | 3 ++- tests/test_classic_functions.py | 2 +- tox.ini | 2 +- 13 files changed, 30 insertions(+), 14 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 769d02cf..1e499975 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 6.0 +current_version = 6.0b1 commit = False tag = False diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index c780e7df..05a681e4 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -26,6 +26,7 @@ sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils +sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils # install build and testing tool @@ -43,7 +44,7 @@ sudo apt-get install -y tox clang-format sudo apt-get install -y postgresql libpq-dev -for pghost in pg10 pg12 pg14 pg15 +for pghost in pg10 pg12 pg14 pg15 pg16 do export PGHOST=$pghost export PGDATABASE=postgres @@ -76,3 +77,7 @@ do psql -c "create extension hstore" test_latin9 psql -c "create extension hstore" test_cyrillic done + +export PGDATABASE=test +export PGUSER=test +export PGPASSWORD=test diff --git a/README.rst b/README.rst index a010b944..e9f9465c 100644 --- a/README.rst +++ b/README.rst @@ -18,6 +18,9 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only +The current version of PyGreSQL supports Python versions 3.7 to 3.12 +and PostgreSQL versions 10 to 16 on the server. + Installation ------------ @@ -28,6 +31,9 @@ The simplest way to install PyGreSQL is to type:: For other ways of installing PyGreSQL and requirements, see the documentation. +Note that PyGreSQL also requires the libpq shared library to be +installed and accessible on the client machine. + Documentation ------------- diff --git a/docs/about.rst b/docs/about.rst index 8235e5cc..18c6b7a6 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -39,6 +39,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL |version| needs PostgreSQL 10 to 15, and Python -3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +The current version PyGreSQL |version| needs PostgreSQL 10 to 16, and Python +3.7 to 3.12. 
If you need to support older PostgreSQL or Python versions, you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/conf.py b/docs/conf.py index 9dd604f2..48cb7dc0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,7 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '6.0' +version = release = '6.0b1' language = 'en' diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 077893a2..9f35f716 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,9 @@ ChangeLog ========= -Version 6.0 (to be released) ----------------------------- +Version 6.0b1 (2023-09-06) +-------------------------- +- Officially support Python 3.12 and PostgreSQL 16 (tested with rc versions). - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). - Converted the standalone modules `pg` and `pgdb` to packages with diff --git a/docs/contents/install.rst b/docs/contents/install.rst index f447abc3..7d28ea59 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -3.7 to 3.11, and PostgreSQL versions 10 to 15. +3.7 to 3.12, and PostgreSQL versions 10 to 16. PyGreSQL will be installed as two packages named ``pg`` (for the classic interface) and ``pgdb`` (for the DB API 2 compliant interface). The former diff --git a/pyproject.toml b/pyproject.toml index e289b38f..30d255e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0" +version = "6.0b1" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. 
Cain", email = "darcy@pygresql.org"}, @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/setup.py b/setup.py index 4fd39c56..d0f70ea0 100755 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -version = '6.0' +version = '6.0b1' if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( @@ -152,6 +152,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index dcb7a5e2..be1b5a42 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -174,7 +174,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(100000 <= server_version < 160000) + self.assertGreaterEqual(server_version, 100000) + self.assertLess(server_version, 170000) def test_attribute_socket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index e53617dd..d1224a53 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -168,7 +168,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(100000 <= server_version < 160000) + self.assertGreaterEqual(server_version, 100000) + self.assertLess(server_version, 170000) self.assertEqual(server_version, self.db.db.server_version) def test_attribute_socket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 01ed752e..4351f794 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -125,7 +125,7 @@ def test_pqlib_version(self): v = pg.get_pqlib_version() self.assertIsInstance(v, int) self.assertGreater(v, 100000) - self.assertLess(v, 160000) + self.assertLess(v, 170000) class TestParseArray(unittest.TestCase): diff --git a/tox.ini b/tox.ini index eae93234..b86f9fea 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs +envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 From 741f734c73a2baaa667ed5980b53fb98090a4f1b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 17:44:35 +0200 Subject: [PATCH 165/194] Install current build tools for development --- .devcontainer/provision.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 05a681e4..7cb14be0 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -30,11 +30,11 @@ sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils # install build and testing tool -python3.7 -m pip install build -python3.8 -m pip install build -python3.9 -m pip install build -python3.10 -m pip 
install build -python3.11 -m pip install build +python3.7 -m pip install -U pip setuptools wheel build +python3.8 -m pip install -U pip setuptools wheel build +python3.9 -m pip install -U pip setuptools wheel build +python3.10 -m pip install -U pip setuptools wheel build +python3.11 -m pip install -U pip setuptools wheel build pip install ruff From 7d93eb222a48ff0b14012f684b316d338cad5dd9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 17:47:18 +0200 Subject: [PATCH 166/194] Avoid segfault when there is a poll error --- ext/pgconn.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index 10e5b780..9ffc0009 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -45,7 +45,7 @@ conn_getattr(connObject *self, PyObject *nameobj) /* postmaster host */ if (!strcmp(name, "host")) { char *r = PQhost(self->cnx); - if (!r || r[0] == '/') /* Pg >= 9.6 can return a Unix socket path */ + if (!r || r[0] == '/') /* this can return a Unix socket path */ r = "localhost"; return PyUnicode_FromString(r); } @@ -1577,7 +1577,6 @@ conn_poll(connObject *self, PyObject *noargs) if (rc == PGRES_POLLING_FAILED) { set_error(InternalError, "Polling failed", self->cnx, NULL); - Py_XDECREF(self); return NULL; } From b96d64f49be97731041948df12c0e868d8ea42d2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 18:38:34 +0200 Subject: [PATCH 167/194] Add coverage to tox file --- tox.ini | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tox.ini b/tox.ini index b86f9fea..dd747abe 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,14 @@ deps = commands = python -m build -n -C strict -C memory-size +[testenv:coverage] +basepython = python3.11 +deps = + coverage>=7,<8 +commands = + coverage run -m unittest discover + coverage html + [testenv] passenv = PG* From f8e79fb17640f5f5e2e5da735b550bd3339b2d38 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 18:39:04 +0200 Subject: [PATCH 168/194] Update bump file --- .bumpversion.cfg | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 1e499975..f9acd8ee 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -12,6 +12,10 @@ serialize = search = version = '{current_version}' replace = version = '{new_version}' +[bumpversion:file:pyproject.toml] +search = version = "{current_version}" +replace = version = "{new_version}" + [bumpversion:file:docs/conf.py] search = version = release = '{current_version}' replace = version = release = '{new_version}' From 439cbdd77b6be77281eb0eafcab1d3130226ec82 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 19:52:34 +0200 Subject: [PATCH 169/194] Use consistent project urls --- .devcontainer/provision.sh | 1 + pyproject.toml | 14 +++++++------- setup.py | 2 +- tox.ini | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 7cb14be0..09acd893 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -39,6 +39,7 @@ python3.11 -m pip install -U pip setuptools wheel build pip install ruff sudo apt-get install -y tox clang-format +pip install -U tox # install PostgreSQL client tools diff --git a/pyproject.toml b/pyproject.toml index 30d255e4..fdbbcbea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,13 @@ classifiers = [ file = "LICENSE.txt" [project.urls] -homepage = "https://pygresql.github.io/" -documentation = "https://pygresql.github.io/contents/" -source = 
"https://github.com/PyGreSQL/PyGreSQL" -issues = "https://github.com/PyGreSQL/PyGreSQL/issues/" -changelog = "https://pygresql.github.io/contents/changelog.html" -download = "https://pygresql.github.io/download/" -"mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" +Homepage = "https://pygresql.github.io/" +Documentation = "https://pygresql.github.io/contents/" +"Source Code" = "https://github.com/PyGreSQL/PyGreSQL" +"Issue Tracker" = "https://github.com/PyGreSQL/PyGreSQL/issues/" +Changelog = "https://pygresql.github.io/contents/changelog.html" +Download = "https://pygresql.github.io/download/" +"Mailing List" = "https://mail.vex.net/mailman/listinfo/pygresql" [tool.ruff] target-version = "py37" diff --git a/setup.py b/setup.py index d0f70ea0..c2d72a61 100755 --- a/setup.py +++ b/setup.py @@ -133,7 +133,7 @@ def finalize_options(self): author="D'Arcy J. M. Cain", author_email="darcy@PyGreSQL.org", url='https://pygresql.github.io/', - download_url='https://pygresql.github.io/contents/download/', + download_url='https://pygresql.github.io/download/', project_urls={ 'Documentation': 'https://pygresql.github.io/contents/', 'Issue Tracker': 'https://github.com/PyGreSQL/PyGreSQL/issues/', diff --git a/tox.ini b/tox.ini index dd747abe..c58f40b1 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = wheel>=0.41 build>=0.10 commands = - python -m build -n -C strict -C memory-size + python -m build -s -n -C strict -C memory-size [testenv:coverage] basepython = python3.11 From e73f4ae9a2ed81b81cc833deae97f807fd25f420 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 16:07:27 +0000 Subject: [PATCH 170/194] Test with Python 3.12 and Postgres 16 on GitHub --- .github/workflows/tests.yml | 12 +++++++----- tox.ini | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ca8e4a36..43da55df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,13 +20,15 @@ jobs: - { python: "3.9", postgres: "13" } - { python: "3.10", postgres: "14" } - { python: "3.11", postgres: "15" } + - { python: "3.12", postgres: "16" } # Opposite extremes of the supported Py/PG range, other architecture - - { python: "3.7", postgres: "15", architecture: "x86" } - - { python: "3.8", postgres: "14", architecture: "x86" } - - { python: "3.9", postgres: "13", architecture: "x86" } - - { python: "3.10", postgres: "12", architecture: "x86" } - - { python: "3.11", postgres: "11", architecture: "x86" } + - { python: "3.7", postgres: "16", architecture: "x86" } + - { python: "3.8", postgres: "15", architecture: "x86" } + - { python: "3.9", postgres: "14", architecture: "x86" } + - { python: "3.10", postgres: "13", architecture: "x86" } + - { python: "3.11", postgres: "12", architecture: "x86" } + - { python: "3.12", postgres: "11", architecture: "x86" } env: PYGRESQL_DB: test diff --git a/tox.ini b/tox.ini index c58f40b1..fd36f2a2 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 -deps = ruff>=0.0.287 +deps = ruff>=0.0.292 commands = ruff setup.py pg pgdb tests @@ -33,8 +33,8 @@ commands = basepython = python3.11 deps = setuptools>=68 - wheel>=0.41 - build>=0.10 + wheel>=0.41,<1 + build>=1,<2 commands = python -m build -s -n -C strict -C memory-size From a581e0448244f439f7bbc3d66ee88879e1da47f6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 17:25:08 +0000 Subject: [PATCH 
171/194] Update checkout action --- .github/workflows/docs.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index aae221a0..7d1ba05a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -11,7 +11,7 @@ jobs: steps: - name: CHeck out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dad89096..267c54c2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43da55df..31a48265 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -53,7 +53,7 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python From cdf9f427ee936a838992f28b478bb31171d7da44 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 20:16:03 +0200 Subject: [PATCH 172/194] In some test environments, there is no SSL support The test started to break on GitHub in September 2023. Might have to do with changes in the Ubuntu docker image. --- tests/test_classic_connection.py | 7 ++++--- tests/test_classic_dbwrapper.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index be1b5a42..3f9427b2 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -195,9 +195,10 @@ def test_attribute_ssl_in_use(self): def test_attribute_ssl_attributes(self): ssl_attributes = self.connection.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) def test_attribute_status(self): status_ok = 1 diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d1224a53..8aa691f5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -190,9 +190,10 @@ def test_attribute_ssl_in_use(self): def test_attribute_ssl_attributes(self): ssl_attributes = self.db.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) def test_attribute_status(self): status_ok = 1 From 3e0de8e7e7d48f34141b473ba45e2b6c0bfef6f3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 18:34:34 +0000 Subject: [PATCH 173/194] Python 3.12 needs setuptools --- tox.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tox.ini b/tox.ini index fd36f2a2..3917f73c 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,8 @@ commands = passenv = PG* PYGRESQL_* +deps = + setuptools>=68 commands = python setup.py clean 
--all build_ext --force --inplace --strict --memory-size python -m unittest {posargs:discover} From 0587d593fe97b42a94735e87bf958e296ccbe704 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 18:53:55 +0000 Subject: [PATCH 174/194] Keep version only in pyproject.toml --- .bumpversion.cfg | 21 --------------------- docs/conf.py | 10 +++++++++- pyproject.toml | 2 +- setup.py | 19 +++++++++++++++---- 4 files changed, 25 insertions(+), 27 deletions(-) delete mode 100644 .bumpversion.cfg diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index f9acd8ee..00000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,21 +0,0 @@ -[bumpversion] -current_version = 6.0b1 -commit = False -tag = False - -parse = (?P\d+)\.(?P\d+)(?:\.(?P\d+))? -serialize = - {major}.{minor}.{patch} - {major}.{minor} - -[bumpversion:file:setup.py] -search = version = '{current_version}' -replace = version = '{new_version}' - -[bumpversion:file:pyproject.toml] -search = version = "{current_version}" -replace = version = "{new_version}" - -[bumpversion:file:docs/conf.py] -search = version = release = '{current_version}' -replace = version = release = '{new_version}' diff --git a/docs/conf.py b/docs/conf.py index 48cb7dc0..1a63dac4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,15 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '6.0b1' +def project_version(): + with open('../pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") + +version = release = project_version() language = 'en' diff --git a/pyproject.toml b/pyproject.toml index fdbbcbea..bfcac161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0b1" +version = "6.0" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. 
Cain", email = "darcy@pygresql.org"}, diff --git a/setup.py b/setup.py index c2d72a61..813ecde8 100755 --- a/setup.py +++ b/setup.py @@ -19,14 +19,25 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -version = '6.0b1' - if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") -with open('README.rst') as f: - long_description = f.read() +def project_version(): + with open('pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") + +def project_readme(): + with open('README.rst') as f: + return f.read() + +version = project_version() + +long_description = project_readme() # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the From 7a9a6fbd9120c77c2d69d4f5975e56fde95a75a9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 19:02:09 +0000 Subject: [PATCH 175/194] Handle ruff complaints --- setup.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 813ecde8..8b1ec5dc 100755 --- a/setup.py +++ b/setup.py @@ -19,11 +19,9 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -if not (3, 7) <= sys.version_info[:2] < (4, 0): - raise Exception( - f"Sorry, PyGreSQL {version} does not support this Python version") def project_version(): + """Read the PyGreSQL version from the pyproject.toml file.""" with open('pyproject.toml') as f: for d in f: if d.startswith("version ="): @@ -31,14 +29,22 @@ def project_version(): return version raise Exception("Cannot determine PyGreSQL version") + def project_readme(): + """Get the content of the README file.""" with open('README.rst') as f: return f.read() + version = project_version() +if not (3, 7) <= sys.version_info[:2] < (4, 0): + raise Exception( + f"Sorry, PyGreSQL {version} does not support this Python version") + long_description = project_readme() + # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the # classic interface, and "pgdb" for the modern DB-API 2.0 interface. From 9bc5a1ec81ef15595645012e9493c26fd96333a9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 19:11:02 +0000 Subject: [PATCH 176/194] Update changelog --- docs/contents/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 9f35f716..ac501b56 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.0 (2023-10-03) +------------------------ +- Tested with the recent releases of Python 3.12 and PostgreSQL 16. +- Make pyproject.toml the only source of truth for the version number. +- Please also note the changes already made in version 6.0b1. + Version 6.0b1 (2023-09-06) -------------------------- - Officially support Python 3.12 and PostgreSQL 16 (tested with rc versions). 
From a5af0d2897f8565f7bbcf9f7fe9b67251ce185f3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 14:53:22 +0000 Subject: [PATCH 177/194] Ignore dll files for Python --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 83732331..8b08bb41 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ *.patch *.pid *.pstats -*.py[co] +*.py[cdo] *.so *.swp From b2e1752c1e0ff18040a2280770edf2686873eddb Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 16:11:22 +0000 Subject: [PATCH 178/194] Properly adapt falsy JSON values (#86) --- pg/adapt.py | 2 +- tests/test_classic_dbwrapper.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pg/adapt.py b/pg/adapt.py index 9cbecaaf..2a5efaa2 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -333,7 +333,7 @@ def _adapt_bytea(self, v: Any) -> str: def _adapt_json(self, v: Any) -> str | None: """Adapt a json parameter.""" - if not v: + if v is None: return None if isinstance(v, str): return v diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8aa691f5..f02955c7 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4390,6 +4390,14 @@ def test_adapt_query_typed_list_with_json(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + def test_adapt_query_typed_list_with_empty_json(self): + format_query = self.adapter.format_query + values: Any = [None, 0, False, '', [], {}] + types = ('json',) * 6 + sql, params = format_query("select %s,%s,%s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4,$5,$6') + self.assertEqual(params, [None, '0', 'false', '', '[]', '{}']) + def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query value: Any = {'one': "it's fine", 'two': 2} From a8507e0f1f1f63c19ae7a85ba22ed5c4e2883070 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 16:21:02 +0000 Subject: [PATCH 179/194] Update ruff and mypy --- pyproject.toml | 29 ++++++++++++++++------------- tox.ini | 6 +++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bfcac161..e720490b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,18 +44,6 @@ Download = "https://pygresql.github.io/download/" [tool.ruff] target-version = "py37" line-length = 79 -select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "N", # pep8-naming - "UP", # pyupgrade - "D", # pydocstyle - "B", # bugbear - "S", # bandit - "SIM", # simplify - "RUF", # ruff -] exclude = [ "__pycache__", "__pypackages__", @@ -71,7 +59,22 @@ exclude = [ "venv", ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "D", # pydocstyle + "B", # bugbear + "S", # bandit + "SIM", # simplify + "RUF", # ruff +] +ignore = ["D203", "D213"] + +[tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] diff --git a/tox.ini b/tox.ini index 3917f73c..96ed1c5e 100644 --- a/tox.ini +++ b/tox.ini @@ -5,13 +5,13 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 -deps = ruff>=0.0.292 +deps = ruff>=0.3.0 commands = - ruff setup.py pg pgdb tests + ruff check setup.py pg pgdb tests [testenv:mypy] basepython = python3.11 -deps = mypy>=1.5.1 +deps = mypy>=1.8.0 commands = mypy pg pgdb 
tests From 8d5b39b35b196edaa8bfbe67de909bfeba8a794d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 21:49:43 +0200 Subject: [PATCH 180/194] Update lint tools and GitHub actions --- .github/workflows/docs.yml | 11 ++++++----- .github/workflows/lint.yml | 4 ++-- .github/workflows/tests.yml | 2 +- pyproject.toml | 4 ++-- tox.ini | 18 +++++++++--------- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7d1ba05a..50248b64 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -6,16 +6,17 @@ on: - main jobs: - build: + docs: + name: Build documentation runs-on: ubuntu-22.04 steps: - - name: CHeck out repository + - name: Check out repository uses: actions/checkout@v4 - - name: Set up Python 3.11 - uses: actions/setup-python@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 - name: Install dependencies run: | sudo apt install libpq-dev diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 267c54c2..c32a6e58 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,13 +14,13 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Install tox run: pip install tox - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: 3.12 - name: Run quality checks run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31a48265..822fabdb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,7 +57,7 @@ jobs: - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Run tests diff --git a/pyproject.toml b/pyproject.toml index e720490b..a3be7012 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ ignore = ["D203", "D213"] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] -python_version = "3.11" +python_version = "3.12" check_untyped_defs = true no_implicit_optional = true strict_optional = true @@ -101,5 +101,5 @@ pg = ["pg.typed"] pgdb = ["pg.typed"] [build-system] -requires = ["setuptools>=68", "wheel>=0.41"] +requires = ["setuptools>=68", "wheel>=0.42"] build-backend = "setuptools.build_meta" diff --git a/tox.ini b/tox.ini index 96ed1c5e..f703abb5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,42 +4,42 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] -basepython = python3.11 -deps = ruff>=0.3.0 +basepython = python3.12 +deps = ruff>=0.3.7 commands = ruff check setup.py pg pgdb tests [testenv:mypy] -basepython = python3.11 -deps = mypy>=1.8.0 +basepython = python3.12 +deps = mypy>=1.9.0 commands = mypy pg pgdb tests [testenv:cformat] -basepython = python3.11 +basepython = python3.12 allowlist_externals = sh commands = sh -c "! 
(clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] -basepython = python3.11 +basepython = python3.12 deps = sphinx>=7,<8 commands = sphinx-build -b html -nEW docs docs/_build/html [testenv:build] -basepython = python3.11 +basepython = python3.12 deps = setuptools>=68 - wheel>=0.41,<1 + wheel>=0.42,<1 build>=1,<2 commands = python -m build -s -n -C strict -C memory-size [testenv:coverage] -basepython = python3.11 +basepython = python3.12 deps = coverage>=7,<8 commands = From 633324d45b15c23de295cd6e680ddbe365df91c7 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 22:05:20 +0200 Subject: [PATCH 181/194] Add docker files to repository --- .devcontainer/Dockerfile | 14 +++++++ .devcontainer/docker-compose.yml | 69 ++++++++++++++++++++++++++++++++ .gitignore | 4 -- 3 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/docker-compose.yml diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..5aced2f4 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,14 @@ +FROM mcr.microsoft.com/devcontainers/base:jammy + +ENV PYTHONUNBUFFERED 1 + +# [Optional] If your requirements rarely change, uncomment this section to add them to the image. +# COPY requirements.txt /tmp/pip-tmp/ +# RUN pip3 --disable-pip-version-check --no-cache-dir install -r /tmp/pip-tmp/requirements.txt \ +# && rm -rf /tmp/pip-tmp + +# [Optional] Uncomment this section to install additional OS packages. +# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ +# && apt-get -y install --no-install-recommends + +CMD ["sleep", "infinity"] diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 00000000..61b13a7c --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,69 @@ +services: + dev: + build: + context: . 
+      dockerfile: ./Dockerfile
+
+    env_file: dev.env
+
+    volumes:
+      - ..:/workspace:cached
+
+    command: sleep infinity
+
+  pg10:
+    image: postgres:10
+    restart: unless-stopped
+    volumes:
+      - postgres-data-10:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg12:
+    image: postgres:12
+    restart: unless-stopped
+    volumes:
+      - postgres-data-12:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg14:
+    image: postgres:14
+    restart: unless-stopped
+    volumes:
+      - postgres-data-14:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg15:
+    image: postgres:15
+    restart: unless-stopped
+    volumes:
+      - postgres-data-15:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg16:
+    image: postgres:16
+    restart: unless-stopped
+    volumes:
+      - postgres-data-16:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+volumes:
+  postgres-data-10:
+  postgres-data-12:
+  postgres-data-14:
+  postgres-data-15:
+  postgres-data-16:
diff --git a/.gitignore b/.gitignore
index 8b08bb41..22c5ce3c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,10 +20,6 @@ _build_doctrees/
 /local/
 /tests/LOCAL_*.py
 
-docker-compose.yml
-Dockerfile
-Vagrantfile
-Vagrantfile-*
 
 .coverage
 .tox/

From 8ec45a29e65f6fe4fe12e903b611db7a0236ac07 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Thu, 18 Apr 2024 22:54:41 +0200
Subject: [PATCH 182/194] Fix minor linting issues

---
 pg/db.py      | 15 +++++++++------
 pgdb/adapt.py |  2 +-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/pg/db.py b/pg/db.py
index f824cc9d..5c8beea7 100644
--- a/pg/db.py
+++ b/pg/db.py
@@ -802,8 +802,9 @@ def get(self, table: str, row: Any,
         adapt = params.add
         col = self.escape_identifier
         what = 'oid, *' if qoid else '*'
-        where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format(
-            col(k), adapt(row[k], attnames[k])) for k in keyname)
+        where = ' AND '.join(
+            f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}'
+            for k in keyname)
         if 'oid' in row:
             if qoid:
                 row[qoid] = row['oid']
@@ -913,8 +914,9 @@ def update(self, table: str, row: dict[str, Any] | None = None, **kw: Any
         params = self.adapter.parameter_list()
         adapt = params.add
         col = self.escape_identifier
-        where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format(
-            col(k), adapt(row[k], attnames[k])) for k in keynames)
+        where = ' AND '.join(
+            f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}'
+            for k in keynames)
         if 'oid' in row:
             if qoid:
                 row[qoid] = row['oid']
@@ -1103,8 +1105,9 @@ def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any
         params = self.adapter.parameter_list()
         adapt = params.add
         col = self.escape_identifier
-        where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format(
-            col(k), adapt(row[k], attnames[k])) for k in keynames)
+        where = ' AND '.join(
+            f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}'
+            for k in keynames)
         if 'oid' in row:
             if qoid:
                 row[qoid] = row['oid']
diff --git a/pgdb/adapt.py b/pgdb/adapt.py
index 92b48a7e..b89986b6 100644
--- a/pgdb/adapt.py
+++ b/pgdb/adapt.py
@@ -33,7 +33,7 @@ def __new__(cls, values: str | Iterable[str]) -> DbType:
         """Create new type object."""
         if isinstance(values, str):
             values = values.split()
-        return super().__new__(cls, values)  # type: 
ignore + return super().__new__(cls, values) def __eq__(self, other: Any) -> bool: """Check whether types are considered equal.""" From 683a63632727d229a684d47bb264c1a9bee37b4c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 20:57:11 +0000 Subject: [PATCH 183/194] Update the year of the copyright --- LICENSE.txt | 2 +- docs/about.rst | 2 +- docs/conf.py | 2 +- docs/copyright.rst | 2 +- ext/pgconn.c | 2 +- ext/pginternal.c | 2 +- ext/pglarge.c | 2 +- ext/pgmodule.c | 2 +- ext/pgnotice.c | 2 +- ext/pgquery.c | 2 +- ext/pgsource.c | 2 +- pg/__init__.py | 2 +- pgdb/__init__.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index eea706fe..b34bf23b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2023 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2024 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.rst b/docs/about.rst index 18c6b7a6..180af459 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -8,7 +8,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2023 by the PyGreSQL team. + | Further modifications are copyright © 2009-2024 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source diff --git a/docs/conf.py b/docs/conf.py index 1a63dac4..45a86cd4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,7 +8,7 @@ project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2023, ' + author +copyright = '2024, ' + author def project_version(): with open('../pyproject.toml') as f: diff --git a/docs/copyright.rst b/docs/copyright.rst index 9a8113ec..60739ef0 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2023 by the PyGreSQL team. +Further modifications copyright (c) 2009-2024 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/ext/pgconn.c b/ext/pgconn.c index 9ffc0009..ddc958ea 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pginternal.c b/ext/pginternal.c index 124661c1..9b3952cc 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. 
*/ diff --git a/ext/pglarge.c b/ext/pglarge.c index 77455361..f19568c4 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 761ae1b7..26b916d6 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index 0252a56f..ca051d88 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgquery.c b/ext/pgquery.c index 6346497d..fe5dda47 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgsource.c b/ext/pgsource.c index 42510b30..4e197578 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pg/__init__.py b/pg/__init__.py index 37447c9e..cb4c7c34 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2023 by the PyGreSQL Development Team +# Copyright (c) 2024 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgdb/__init__.py b/pgdb/__init__.py index b9a4449a..2604074a 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2023 by the PyGreSQL Development Team +# Copyright (c) 2024 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. 
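A note on the WHERE-clause refactoring in PATCH 182 above: all three hunks in
pg/db.py build the same parameterized WHERE clause, and they spell the equality
test as OPERATOR(pg_catalog.=) so that the built-in operator is used even if a
user-defined "=" operator appears earlier in the search_path. What follows is a
minimal standalone sketch of that pattern; build_where, adapt and col are
simplified stand-ins invented for illustration, not PyGreSQL's actual
Adapter.parameter_list() or escape_identifier() implementations.

    # Sketch only: the WHERE-clause pattern from the pg/db.py hunks above,
    # with toy stand-ins for parameter adaptation and identifier quoting.

    def build_where(row, attnames, keynames, adapt, col):
        # One schema-qualified equality test per key column, joined with AND.
        return ' AND '.join(
            f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}'
            for k in keynames)

    params = []

    def adapt(value, typ):
        params.append(value)       # collect the actual parameter value
        return f'${len(params)}'   # emit a numbered placeholder

    def col(name):
        return f'"{name}"'         # naive identifier quoting, for illustration

    where = build_where({'id': 1, 'name': 'abc'},
                        {'id': 'int', 'name': 'text'},
                        ['id', 'name'], adapt, col)
    assert where == ('"id" OPERATOR(pg_catalog.=) $1'
                     ' AND "name" OPERATOR(pg_catalog.=) $2')
    assert params == [1, 'abc']

The f-string form is also easier to read than the old str.format() call, since
each placeholder sits exactly where its value ends up in the SQL text.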
From 0ab37131bce025668717735969a20667bd9bcb4b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 21:06:01 +0000 Subject: [PATCH 184/194] Fix GitHub action --- .github/workflows/lint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c32a6e58..9e5c0bde 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,11 +14,11 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v5 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.12 - name: Run quality checks From d55137969d33130e9c025dc4da77c87522008c4b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Apr 2024 21:34:09 +0200 Subject: [PATCH 185/194] Bump version number --- docs/contents/changelog.rst | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index ac501b56..6afcb1e8 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 6.0.1 (2024-04-19) +-------------------------- +- Properly adapt falsy JSON values (#86) + Version 6.0 (2023-10-03) ------------------------ - Tested with the recent releases of Python 3.12 and PostgreSQL 16. diff --git a/pyproject.toml b/pyproject.toml index a3be7012..ef1de2a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0" +version = "6.0.1" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, From 40ba811a3b424088cebcc7338e33a3f265c0fb87 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Apr 2024 20:33:25 +0000 Subject: [PATCH 186/194] Fix issues with provision.sh --- .devcontainer/devcontainer.json | 2 +- .devcontainer/provision.sh | 37 ++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b9fbaaeb..0333b8e6 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -56,7 +56,7 @@ // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "bash /workspace/.devcontainer/provision.sh" + "postCreateCommand": "sudo bash /workspace/.devcontainer/provision.sh" // Configure tool-specific properties. // "customizations": {}, // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 
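Before the provision.sh part of this patch, a short aside on the fix recorded
in PATCH 185's changelog entry above ("Properly adapt falsy JSON values",
issue #86): the pitfall is the classic truthiness test, which treats False, 0
and '' like "no value". The sketch below is generic, assumed code written for
illustration; it is not PyGreSQL's actual Json adapter.

    # Generic sketch of the falsy-value pitfall behind issue #86.
    import json

    def adapt_json_buggy(value):
        # Truthiness test: False, 0, '' and [] all fall through to None.
        return json.dumps(value) if value else None

    def adapt_json_fixed(value):
        # Identity test: only None means "no value".
        return json.dumps(value) if value is not None else None

    assert adapt_json_buggy(False) is None     # falsy JSON value silently lost
    assert adapt_json_fixed(False) == 'false'  # properly adapted
    assert adapt_json_fixed(0) == '0'

The provision.sh portion of PATCH 186 continues below.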
diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh
index 09acd893..5515b687 100644
--- a/.devcontainer/provision.sh
+++ b/.devcontainer/provision.sh
@@ -4,46 +4,49 @@
 export DEBIAN_FRONTEND=noninteractive
 
-sudo apt-get update
-sudo apt-get -y upgrade
+apt-get update
+apt-get -y upgrade
 
 # install base utilities and configure time zone
 
-sudo ln -fs /usr/share/zoneinfo/UTC /etc/localtime
-sudo apt-get install -y apt-utils software-properties-common
-sudo apt-get install -y tzdata
-sudo dpkg-reconfigure --frontend noninteractive tzdata
+ln -fs /usr/share/zoneinfo/UTC /etc/localtime
+apt-get install -y apt-utils software-properties-common
+apt-get install -y tzdata
+dpkg-reconfigure --frontend noninteractive tzdata
 
-sudo apt-get install -y rpm wget zip
+apt-get install -y rpm wget zip
 
 # install all supported Python versions
 
-sudo add-apt-repository -y ppa:deadsnakes/ppa
-sudo apt-get update
+add-apt-repository -y ppa:deadsnakes/ppa
+apt-get update
 
-sudo apt-get install -y python3.7 python3.7-dev python3.7-distutils
-sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils
-sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils
-sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils
-sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils
-sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils
+apt-get install -y python3.7 python3.7-dev python3.7-distutils
+apt-get install -y python3.8 python3.8-dev python3.8-distutils
+apt-get install -y python3.9 python3.9-dev python3.9-distutils
+apt-get install -y python3.10 python3.10-dev python3.10-distutils
+apt-get install -y python3.11 python3.11-dev python3.11-distutils
+apt-get install -y python3.12 python3.12-dev python3.12-distutils
 
 # install build and testing tool
 
+python -m ensurepip -U
+
 python3.7 -m pip install -U pip setuptools wheel build
 python3.8 -m pip install -U pip setuptools wheel build
 python3.9 -m pip install -U pip setuptools wheel build
 python3.10 -m pip install -U pip setuptools wheel build
 python3.11 -m pip install -U pip setuptools wheel build
+python3.12 -m pip install -U pip setuptools wheel build
 
 pip install ruff
 
-sudo apt-get install -y tox clang-format
+apt-get install -y tox clang-format
 
 pip install -U tox
 
 # install PostgreSQL client tools
 
-sudo apt-get install -y postgresql libpq-dev
+apt-get install -y postgresql libpq-dev
 
 for pghost in pg10 pg12 pg14 pg15 pg16
 do
From 0f18e8060c9037ac555c133f3c8ec80202d33492 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 27 Jul 2024 17:39:38 +0200
Subject: [PATCH 187/194] Fix doc for DB.delete (#87)

---
 docs/contents/pg/db_wrapper.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst
index 1dbd18ef..b9e72b69 100644
--- a/docs/contents/pg/db_wrapper.rst
+++ b/docs/contents/pg/db_wrapper.rst
@@ -715,7 +715,7 @@ delete -- delete a row from a database table
     Delete a row from a database table
 
     :param str table: name of table
-    :param dict d: optional dictionary of values
+    :param dict row: optional dictionary of values
     :param col: optional keyword arguments for updating the dictionary
     :rtype: None
     :raises pg.ProgrammingError: table has no primary key,
From a07a71d3fba3e3a653623f7274109f658fcd19a0 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Sat, 27 Jul 2024 17:38:35 +0000
Subject: [PATCH 188/194] Use newer mypy and ruff

---
 tests/test_classic_connection.py | 8 ++++----
 tests/test_dbapi20.py            | 2 +-
tox.ini | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 3f9427b2..180563ed 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1997,10 +1997,10 @@ def test_inserttable_byte_values(self): row_bytes = tuple( s.encode() if isinstance(s, str) else s for s in row_unicode) - data = [row_bytes] * 2 - self.c.inserttable('test', data) - data = [row_unicode] * 2 - self.assertEqual(self.get_back(), data) + data_bytes = [row_bytes] * 2 + self.c.inserttable('test', data_bytes) + data_unicode = [row_unicode] * 2 + self.assertEqual(self.get_back(), data_unicode) def test_inserttable_unicode_utf8(self): try: diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index ef4857d3..0e70e073 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -161,7 +161,7 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): - def row_factory(self, row): + def row_factory(self, row): # type: ignore[override] description = self.description assert isinstance(description, list) return {f'column {desc[0]}': value diff --git a/tox.ini b/tox.ini index f703abb5..0679e456 100644 --- a/tox.ini +++ b/tox.ini @@ -5,13 +5,13 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.12 -deps = ruff>=0.3.7 +deps = ruff>=0.5,<0.6 commands = ruff check setup.py pg pgdb tests [testenv:mypy] basepython = python3.12 -deps = mypy>=1.9.0 +deps = mypy>=1.11,<1.12 commands = mypy pg pgdb tests @@ -43,7 +43,7 @@ basepython = python3.12 deps = coverage>=7,<8 commands = - coverage run -m unittest discover + coverage run -m unittest discover -v coverage html [testenv] @@ -54,4 +54,4 @@ deps = setuptools>=68 commands = python setup.py clean --all build_ext --force --inplace --strict --memory-size - python -m unittest {posargs:discover} + python -m unittest {posargs:discover -v} From 487452e988e212db426780c3851323629c27b55b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 12:18:55 +0000 Subject: [PATCH 189/194] Update dependencies and supported versions --- .devcontainer/docker-compose.yml | 11 ++++ .devcontainer/provision.sh | 4 +- .github/workflows/docs.yml | 6 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 14 ++-- README.rst | 4 +- docs/about.rst | 4 +- docs/contents/install.rst | 2 +- pg/__init__.py | 108 +++++++++++++++++++++++-------- pg/adapt.py | 10 ++- pg/cast.py | 18 ++++-- pg/core.py | 92 +++++++++++++++++++------- pg/error.py | 7 +- pg/helpers.py | 10 ++- pgdb/__init__.py | 71 +++++++++++++++----- pgdb/adapt.py | 36 +++++++++-- pgdb/cast.py | 23 +++++-- pyproject.toml | 3 +- setup.py | 1 + tests/test_classic_connection.py | 2 +- tests/test_classic_largeobj.py | 15 ++--- tox.ini | 20 +++--- 22 files changed, 340 insertions(+), 123 deletions(-) diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 61b13a7c..541d63e9 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -61,9 +61,20 @@ services: POSTGRES_DB: postgres POSTGRES_PASSWORD: postgres + pg17: + image: postgres:17 + restart: unless-stopped + volumes: + - postgres-data-17:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + volumes: postgres-data-10: postgres-data-12: postgres-data-14: postgres-data-15: postgres-data-16: + postgres-data-17: diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 
5515b687..1ca7b020 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -27,6 +27,7 @@ apt-get install -y python3.9 python3.9-dev python3.9-distutils apt-get install -y python3.10 python3.10-dev python3.10-distutils apt-get install -y python3.11 python3.11-dev python3.11-distutils apt-get install -y python3.12 python3.12-dev python3.12-distutils +apt-get install -y python3.13 python3.13-dev python3.13-distutils # install build and testing tool @@ -38,6 +39,7 @@ python3.9 -m pip install -U pip setuptools wheel build python3.10 -m pip install -U pip setuptools wheel build python3.11 -m pip install -U pip setuptools wheel build python3.12 -m pip install -U pip setuptools wheel build +python3.13 -m pip install -U pip setuptools wheel build pip install ruff @@ -48,7 +50,7 @@ pip install -U tox apt-get install -y postgresql libpq-dev -for pghost in pg10 pg12 pg14 pg15 pg16 +for pghost in pg10 pg12 pg14 pg15 pg16 pg17 do export PGHOST=$pghost export PGDATABASE=postgres diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 50248b64..d88cd64a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -13,16 +13,16 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v4 - - name: Set up Python 3.12 + - name: Set up Python 3.13 uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.13 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=7,<8" + pip install "sphinx>=8,<9" - name: Create docs with Sphinx run: | cd docs diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9e5c0bde..66d79095 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.13 - name: Run quality checks run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 822fabdb..920e3f3e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,14 +21,16 @@ jobs: - { python: "3.10", postgres: "14" } - { python: "3.11", postgres: "15" } - { python: "3.12", postgres: "16" } + - { python: "3.13", postgres: "17" } # Opposite extremes of the supported Py/PG range, other architecture - - { python: "3.7", postgres: "16", architecture: "x86" } - - { python: "3.8", postgres: "15", architecture: "x86" } - - { python: "3.9", postgres: "14", architecture: "x86" } - - { python: "3.10", postgres: "13", architecture: "x86" } - - { python: "3.11", postgres: "12", architecture: "x86" } - - { python: "3.12", postgres: "11", architecture: "x86" } + - { python: "3.7", postgres: "17", architecture: "x86" } + - { python: "3.8", postgres: "16", architecture: "x86" } + - { python: "3.9", postgres: "15", architecture: "x86" } + - { python: "3.10", postgres: "14", architecture: "x86" } + - { python: "3.11", postgres: "13", architecture: "x86" } + - { python: "3.12", postgres: "12", architecture: "x86" } + - { python: "3.13", postgres: "11", architecture: "x86" } env: PYGRESQL_DB: test diff --git a/README.rst b/README.rst index e9f9465c..46a09c2b 100644 --- a/README.rst +++ b/README.rst @@ -18,8 +18,8 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only -The current version of PyGreSQL supports Python versions 3.7 to 3.12 -and PostgreSQL versions 10 to 16 on 
the server. +The current version of PyGreSQL supports Python versions 3.7 to 3.13 +and PostgreSQL versions 10 to 17 on the server. Installation ------------ diff --git a/docs/about.rst b/docs/about.rst index 180af459..ec1dbd2f 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -39,6 +39,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL |version| needs PostgreSQL 10 to 16, and Python -3.7 to 3.12. If you need to support older PostgreSQL or Python versions, +The current version PyGreSQL |version| needs PostgreSQL 10 to 17, and Python +3.7 to 3.13. If you need to support older PostgreSQL or Python versions, you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 7d28ea59..23694528 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -3.7 to 3.12, and PostgreSQL versions 10 to 16. +3.7 to 3.13, and PostgreSQL versions 10 to 17. PyGreSQL will be installed as two packages named ``pg`` (for the classic interface) and ``pgdb`` (for the DB API 2 compliant interface). The former diff --git a/pg/__init__.py b/pg/__init__.py index cb4c7c34..eeda3b73 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -99,34 +99,86 @@ from .notify import NotificationHandler __all__ = [ - 'DB', 'Adapter', - 'NotificationHandler', 'Typecasts', - 'Bytea', 'Hstore', 'Json', 'Literal', - 'Error', 'Warning', - 'DataError', 'DatabaseError', - 'IntegrityError', 'InterfaceError', 'InternalError', - 'InvalidResultError', 'MultipleResultsError', - 'NoResultError', 'NotSupportedError', - 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', 'RowCache', - 'INV_READ', 'INV_WRITE', - 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', - 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', - 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', - 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', - 'TRANS_INTRANS', 'TRANS_UNKNOWN', - 'cast_array', 'cast_hstore', 'cast_record', - 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', - 'get_array', 'get_bool', 'get_bytea_escaped', - 'get_datestyle', 'get_decimal', 'get_decimal_point', - 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_pqlib_version', 'get_typecast', - 'set_array', 'set_bool', 'set_bytea_escaped', - 'set_datestyle', 'set_decimal', 'set_decimal_point', - 'set_defbase', 'set_defhost', 'set_defopt', - 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', 'set_typecast', - 'version', '__version__', + 'DB', + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Adapter', + 'Bytea', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 'Json', + 'Literal', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 
'NotificationHandler', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'RowCache', + 'Typecasts', + 'Warning', + '__version__', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'get_typecast', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'set_typecast', + 'unescape_bytea', + 'version', ] __version__ = version diff --git a/pg/adapt.py b/pg/adapt.py index 2a5efaa2..97e0391c 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -21,8 +21,14 @@ from .db import DB __all__ = [ - 'Adapter', 'Bytea', 'DbType', 'DbTypes', - 'Hstore', 'Literal', 'Json', 'UUID' + 'UUID', + 'Adapter', + 'Bytea', + 'DbType', + 'DbTypes', + 'Hstore', + 'Json', + 'Literal' ] diff --git a/pg/cast.py b/pg/cast.py index ad1758be..98baa8f6 100644 --- a/pg/cast.py +++ b/pg/cast.py @@ -25,10 +25,20 @@ from .tz import timezone_as_offset __all__ = [ - 'cast_bool', 'cast_json', 'cast_num', 'cast_money', 'cast_int2vector', - 'cast_date', 'cast_time', 'cast_timetz', 'cast_interval', - 'cast_timestamp','cast_timestamptz', - 'Typecasts', 'get_typecast', 'set_typecast' + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_json', + 'cast_money', + 'cast_num', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'get_typecast', + 'set_typecast' ] def get_args(func: Callable) -> list: diff --git a/pg/core.py b/pg/core.py index e20bdbd0..4d0c03c0 100644 --- a/pg/core.py +++ b/pg/core.py @@ -108,29 +108,73 @@ ) __all__ = [ - 'Error', 'Warning', - 'DataError', 'DatabaseError', - 'IntegrityError', 'InterfaceError', 'InternalError', - 'InvalidResultError', 'MultipleResultsError', - 'NoResultError', 'NotSupportedError', - 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', 'LargeObject', - 'INV_READ', 'INV_WRITE', - 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', - 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', - 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', - 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', - 'TRANS_INTRANS', 'TRANS_UNKNOWN', - 'cast_array', 'cast_hstore', 'cast_record', - 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', - 'get_array', 'get_bool', 'get_bytea_escaped', - 'get_datestyle', 'get_decimal', 'get_decimal_point', - 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_pqlib_version', - 'set_array', 'set_bool', 'set_bytea_escaped', - 'set_datestyle', 'set_decimal', 'set_decimal_point', - 'set_defbase', 'set_defhost', 'set_defopt', - 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 
'LargeObject', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'Warning', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'unescape_bytea', 'version', ] diff --git a/pg/error.py b/pg/error.py index 484a1252..f4b9fd0f 100644 --- a/pg/error.py +++ b/pg/error.py @@ -14,7 +14,12 @@ ) __all__ = [ - 'error', 'db_error', 'if_error', 'int_error', 'op_error', 'prg_error' + 'db_error', + 'error', + 'if_error', + 'int_error', + 'op_error', + 'prg_error' ] # Error messages diff --git a/pg/helpers.py b/pg/helpers.py index 53689f6a..9d176740 100644 --- a/pg/helpers.py +++ b/pg/helpers.py @@ -13,8 +13,14 @@ SomeNamedTuple = Any # alias for accessing arbitrary named tuples __all__ = [ - 'quote_if_unqualified', 'oid_key', 'QuoteDict', 'RowCache', - 'dictiter', 'namediter', 'namednext', 'scalariter' + 'QuoteDict', + 'RowCache', + 'dictiter', + 'namediter', + 'namednext', + 'oid_key', + 'quote_if_unqualified', + 'scalariter' ] diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 2604074a..5db2fd46 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -121,21 +121,62 @@ from .cursor import Cursor __all__ = [ - 'Connection', 'Cursor', - 'Date', 'Time', 'Timestamp', - 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', - 'Binary', 'Interval', 'Uuid', - 'Hstore', 'Json', 'Literal', 'DbType', - 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', - 'SMALLINT', 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', - 'DATE', 'TIME', 'TIMESTAMP', 'INTERVAL', - 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', - 'Error', 'Warning', - 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', - 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - 'get_typecast', 'set_typecast', 'reset_typecast', - 'apilevel', 'connect', 'paramstyle', 'shortcutmethods', 'threadsafety', - 'version', '__version__', + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'Binary', + 'Connection', + 'Cursor', + 'DataError', + 'DatabaseError', + 'Date', + 'DateFromTicks', + 'DbType', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'Interval', + 'Json', + 'Literal', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks', + 'Uuid', + 'Warning', + '__version__', + 'apilevel', + 'connect', + 'get_typecast', + 'paramstyle', + 'reset_typecast', + 'set_typecast', + 'shortcutmethods', + 'threadsafety', + 'version', ] __version__ = version diff --git a/pgdb/adapt.py b/pgdb/adapt.py index b89986b6..f657b190 100644 --- a/pgdb/adapt.py +++ b/pgdb/adapt.py @@ -12,12 +12,36 @@ from .typecode import TypeCode __all__ = [ - 'DbType', 'ArrayType', 'RecordType', - 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 
'ROWID', 'BOOL', 'SMALLINT', - 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', - 'TIMESTAMP', 'INTERVAL', 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', - 'Date', 'Time', 'Timestamp', - 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks' + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'ArrayType', + 'Date', + 'DateFromTicks', + 'DbType', + 'RecordType', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks' ] diff --git a/pgdb/cast.py b/pgdb/cast.py index 03367506..49b4bd84 100644 --- a/pgdb/cast.py +++ b/pgdb/cast.py @@ -24,11 +24,24 @@ from .typecode import TypeCode __all__ = [ - 'Decimal', 'decimal_type', 'cast_bool', 'cast_money', - 'cast_int2vector', 'cast_date', 'cast_time', 'cast_interval', - 'cast_timetz', 'cast_timestamp', 'cast_timestamptz', - 'get_typecast', 'set_typecast', 'reset_typecast', - 'Typecasts', 'LocalTypecasts', 'TypeCache', 'FieldInfo' + 'Decimal', + 'FieldInfo', + 'LocalTypecasts', + 'TypeCache', + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_money', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'decimal_type', + 'get_typecast', + 'reset_typecast', + 'set_typecast' ] diff --git a/pyproject.toml b/pyproject.toml index ef1de2a6..3ef4d645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", @@ -78,7 +79,7 @@ ignore = ["D203", "D213"] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] -python_version = "3.12" +python_version = "3.13" check_untyped_defs = true no_implicit_optional = true strict_optional = true diff --git a/setup.py b/setup.py index 8b1ec5dc..950364c0 100755 --- a/setup.py +++ b/setup.py @@ -170,6 +170,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 180563ed..7234ffb6 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -984,7 +984,7 @@ def test_query_with_bool_params(self, bool_enabled=None): pg.set_bool(bool_enabled) try: bool_on = bool_enabled or bool_enabled is None - v_false, v_true = (False, True) if bool_on else 'ft' + v_false, v_true = (False, True) if bool_on else ('f', 't') r_false, r_true = [(v_false,)], [(v_true,)] self.assertEqual(query("select false").getresult(), r_false) self.assertEqual(query("select true").getresult(), r_true) diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 4fb8773c..7c53053d 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -112,7 +112,7 @@ def test_lo_import(self): fname = 'temp_test_pg_largeobj_import.txt' f = open(fname, 'wb') # noqa: SIM115 else: - f = tempfile.NamedTemporaryFile() + f = tempfile.NamedTemporaryFile() # noqa: SIM115 fname = f.name data = b'some data to be 
imported'
         f.write(data)
@@ -420,7 +420,7 @@ def test_export(self):
             fname = 'temp_test_pg_largeobj_export.txt'
             f = open(fname, 'wb')  # noqa: SIM115
         else:
-            f = tempfile.NamedTemporaryFile()
+            f = tempfile.NamedTemporaryFile()  # noqa: SIM115
             fname = f.name
         data = b'some data to be exported'
         self.obj.open(pg.INV_WRITE)
@@ -441,12 +441,11 @@ def test_export(self):
 
     def test_export_in_existent(self):
         export = self.obj.export
-        f = tempfile.NamedTemporaryFile()
-        self.obj.open(pg.INV_WRITE)
-        self.obj.close()
-        self.pgcnx.query(f'select lo_unlink({self.obj.oid})')
-        self.assertRaises(IOError, export, f.name)
-        f.close()
+        with tempfile.NamedTemporaryFile() as f:
+            self.obj.open(pg.INV_WRITE)
+            self.obj.close()
+            self.pgcnx.query(f'select lo_unlink({self.obj.oid})')
+            self.assertRaises(IOError, export, f.name)
 
 
 if __name__ == '__main__':
diff --git a/tox.ini b/tox.ini
index 0679e456..e89c7d73 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,36 +1,36 @@
 # config file for tox
 
 [tox]
-envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs
+envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs
 
 [testenv:ruff]
-basepython = python3.12
-deps = ruff>=0.5,<0.6
+basepython = python3.13
+deps = ruff>=0.8,<0.9
 commands =
     ruff check setup.py pg pgdb tests
 
 [testenv:mypy]
-basepython = python3.12
-deps = mypy>=1.11,<1.12
+basepython = python3.13
+deps = mypy>=1.13,<1.14
 commands =
     mypy pg pgdb tests
 
 [testenv:cformat]
-basepython = python3.12
+basepython = python3.13
 allowlist_externals = sh
 commands =
     sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)"
 
 [testenv:docs]
-basepython = python3.12
+basepython = python3.13
 deps =
-    sphinx>=7,<8
+    sphinx>=8,<9
 commands =
     sphinx-build -b html -nEW docs docs/_build/html
 
 [testenv:build]
-basepython = python3.12
+basepython = python3.13
 deps =
     setuptools>=68
     wheel>=0.42,<1
@@ -39,7 +39,7 @@ commands =
     python -m build -s -n -C strict -C memory-size
 
 [testenv:coverage]
-basepython = python3.12
+basepython = python3.13
 deps =
     coverage>=7,<8
 commands =
From a29e5822c90a3a5fe1a86baf8571f89843c1afef Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Thu, 5 Dec 2024 14:10:44 +0000
Subject: [PATCH 190/194] Test should work with Pg 17 client and newer

---
 tests/test_classic_connection.py | 4 ++--
 tests/test_classic_dbwrapper.py  | 4 ++--
 tests/test_classic_functions.py  | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py
index 7234ffb6..90d69a59 100755
--- a/tests/test_classic_connection.py
+++ b/tests/test_classic_connection.py
@@ -174,8 +174,8 @@ def test_attribute_protocol_version(self):
     def test_attribute_server_version(self):
         server_version = self.connection.server_version
         self.assertIsInstance(server_version, int)
-        self.assertGreaterEqual(server_version, 100000)
-        self.assertLess(server_version, 170000)
+        self.assertGreaterEqual(server_version, 100000)  # >= 10.0
+        self.assertLess(server_version, 200000)  # < 20.0
 
     def test_attribute_socket(self):
         socket = self.connection.socket
diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py
index f02955c7..1d64c754 100755
--- a/tests/test_classic_dbwrapper.py
+++ b/tests/test_classic_dbwrapper.py
@@ -168,8 +168,8 @@ def test_attribute_protocol_version(self):
     def test_attribute_server_version(self):
         server_version = self.db.server_version
         self.assertIsInstance(server_version, int)
-        self.assertGreaterEqual(server_version, 100000)
-        self.assertLess(server_version, 170000)
+        
self.assertGreaterEqual(server_version, 100000) # >= 10.0 + self.assertLess(server_version, 200000) # < 20.0 self.assertEqual(server_version, self.db.db.server_version) def test_attribute_socket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 4351f794..d1bde01c 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -124,8 +124,8 @@ def test_pqlib_version(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, int) - self.assertGreater(v, 100000) - self.assertLess(v, 170000) + self.assertGreater(v, 100000) # >= 10.0 + self.assertLess(v, 200000) # < 20.0 class TestParseArray(unittest.TestCase): From 6ee4c4565bf20332656503ef9f3201cfec2eaa18 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 14:13:32 +0000 Subject: [PATCH 191/194] Make tox work with Python 3.7 --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index e89c7d73..2359c8df 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,9 @@ [tox] envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs +requires = # this is needed for compatibility with Python 3.7 + pip<24.1 + virtualenv<20.27 [testenv:ruff] basepython = python3.13 From 417f5430f5375550e6955e185b60742c52c861a0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 16:51:23 +0100 Subject: [PATCH 192/194] Bump minor version --- docs/contents/changelog.rst | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 6afcb1e8..ad5f7f0e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 6.1.0 (2024-12-05) +-------------------------- +- Support Python 3.13 and PostgreSQL 17. + Version 6.0.1 (2024-04-19) -------------------------- - Properly adapt falsy JSON values (#86) diff --git a/pyproject.toml b/pyproject.toml index 3ef4d645..01b5086f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0.1" +version = "6.1.0" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. 
Cain", email = "darcy@pygresql.org"}, From fae41b5cfd1c28a405839d9e3ed5938432041f64 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 17:43:33 +0100 Subject: [PATCH 193/194] Make it compile with latest MSVSC --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 950364c0..bf652276 100755 --- a/setup.py +++ b/setup.py @@ -136,7 +136,7 @@ def finalize_options(self): define_macros.append(('MS_WIN64', None)) elif compiler == 'msvc': # Microsoft Visual C++ extra_compile_args[1:] = [ - '-J', '-W3', '-WX', + '-J', '-W3', '-WX', '-wd4391', '-Dinline=__inline'] # needed for MSVC 9 From ca4392e98febb5b0b50cb087afc3a09538a07d17 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 6 Jan 2025 16:02:14 +0000 Subject: [PATCH 194/194] Update year of copyright --- LICENSE.txt | 2 +- docs/about.rst | 2 +- docs/conf.py | 2 +- docs/copyright.rst | 2 +- ext/pgconn.c | 2 +- ext/pginternal.c | 2 +- ext/pglarge.c | 2 +- ext/pgmodule.c | 2 +- ext/pgnotice.c | 2 +- ext/pgquery.c | 2 +- ext/pgsource.c | 2 +- pg/__init__.py | 2 +- pgdb/__init__.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index b34bf23b..e905706e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2024 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2025 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.rst b/docs/about.rst index ec1dbd2f..10ceaf59 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -8,7 +8,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2024 by the PyGreSQL team. + | Further modifications are copyright © 2009-2025 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source diff --git a/docs/conf.py b/docs/conf.py index 45a86cd4..f25d78e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,7 +8,7 @@ project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2024, ' + author +copyright = '2025, ' + author def project_version(): with open('../pyproject.toml') as f: diff --git a/docs/copyright.rst b/docs/copyright.rst index 60739ef0..bf7d9b04 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2024 by the PyGreSQL team. +Further modifications copyright (c) 2009-2025 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/ext/pgconn.c b/ext/pgconn.c index ddc958ea..783eaffc 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. 
*/ diff --git a/ext/pginternal.c b/ext/pginternal.c index 9b3952cc..25290950 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pglarge.c b/ext/pglarge.c index f19568c4..1b817b25 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 26b916d6..916adda2 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index ca051d88..c56b249f 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgquery.c b/ext/pgquery.c index fe5dda47..b87eba18 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgsource.c b/ext/pgsource.c index 4e197578..bbec2f86 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pg/__init__.py b/pg/__init__.py index eeda3b73..c3b7f4e9 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2024 by the PyGreSQL Development Team +# Copyright (c) 2025 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 5db2fd46..132ce292 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2024 by the PyGreSQL Development Team +# Copyright (c) 2025 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions.
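A closing note on the server_version bounds asserted in PATCH 190 above: since
PostgreSQL 10, the integer version number reported by the server and by libpq
is major * 10000 + minor, so the bounds 100000 and 200000 read as "at least
10.0, below 20.0". The helper below is illustrative only and not part of
PyGreSQL.

    # Illustrative decoding of PostgreSQL's integer version numbers.

    def decode_pg_version(v: int) -> str:
        """Return e.g. '17.2' for 170002 (PostgreSQL 10+ encoding)."""
        major, minor = divmod(v, 10000)
        return f'{major}.{minor}'

    assert decode_pg_version(100000) == '10.0'  # lower bound in the tests
    assert decode_pg_version(170002) == '17.2'  # a PostgreSQL 17 point release
    assert 100000 <= 170002 < 200000            # i.e. 10.0 <= version < 20.0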